Merge branch 'main' into main
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "MongoKVStorage",
            "RedisKVStorage",
            "TiDBKVStorage",
            "PGKVStorage",
            "OracleKVStorage",
            "MongoKVStorage",
            # "TiDBKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "MongoGraphStorage",
            "TiDBGraphStorage",
            "AGEStorage",
            "GremlinStorage",
            "PGGraphStorage",
            "OracleGraphStorage",
            # "AGEStorage",
            # "MongoGraphStorage",
            # "TiDBGraphStorage",
            # "GremlinStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "TiDBVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
            # "TiDBVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
        "implementations": [
            "JsonDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    "JsonKVStorage": [],
    "MongoKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "OracleKVStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [],
    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    "OracleGraphStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [],
    "ChromaVectorDBStorage": [],
    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "OracleVectorDBStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    # "TiDBKVStorage": ".kg.tidb_impl",
    # "TiDBVectorDBStorage": ".kg.tidb_impl",
    # "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    # "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
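STORAGES maps each implementation name to the module that defines it, which suggests the loader resolves entries lazily via relative imports. A minimal sketch of that resolution, assuming the registry lives in the lightrag package (the helper name is illustrative, not part of this commit):

    import importlib

    def load_storage_class(name: str, registry: dict[str, str], package: str = "lightrag"):
        # e.g. registry["Neo4JStorage"] == ".kg.neo4j_impl" (relative module path)
        module = importlib.import_module(registry[name], package=package)
        # by convention the class name matches the registry key
        return getattr(module, name)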
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")

import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
import psycopg  # type: ignore
from psycopg.rows import namedtuple_row  # type: ignore
from psycopg_pool import AsyncConnectionPool, PoolTimeout  # type: ignore


class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
    async def index_done_callback(self) -> None:
        # AGE handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all nodes and relationships in the graph.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            query = """
                    MATCH (n)
                    DETACH DELETE n
                    """
            await self._query(query)
            logger.info(f"Successfully dropped all data from graph {self.graph_name}")
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph {self.graph_name}: {e}")
            return {"status": "error", "message": str(e)}
@@ -1,4 +1,5 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
if not pm.is_installed("chromadb"):
    pm.install("chromadb")

from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from chromadb import HttpClient, PersistentClient  # type: ignore
from chromadb.config import Settings  # type: ignore


@final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will delete all documents from the ChromaDB collection.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Get all IDs in the collection
            result = self._collection.get(include=[])
            if result and result["ids"] and len(result["ids"]) > 0:
                # Delete all documents
                self._collection.delete(ids=result["ids"])

            logger.info(
                f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
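For context, collection.get(include=[]) is the standard ChromaDB idiom for fetching only record IDs, which keeps the pre-delete scan in drop() cheap. A small self-contained sketch of the same clear-the-collection pattern, using an in-memory client and made-up IDs:

    import chromadb

    client = chromadb.EphemeralClient()  # in-memory client, for demonstration only
    col = client.get_or_create_collection("demo")
    col.add(ids=["a", "b"], embeddings=[[1.0, 0.0], [0.0, 1.0]])

    result = col.get(include=[])       # returns IDs only; skips embeddings/documents
    if result["ids"]:
        col.delete(ids=result["ids"])  # empty the collection without dropping it
    print(col.count())                 # -> 0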
@@ -11,16 +11,20 @@ import pipmaster as pm
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage

if not pm.is_installed("faiss"):
    pm.install("faiss")

import faiss  # type: ignore
from .shared_storage import (
    get_storage_lock,
    get_update_flag,
    set_all_update_flags,
)

import faiss  # type: ignore

USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"

if not pm.is_installed(FAISS_PACKAGE):
    pm.install(FAISS_PACKAGE)


@final
@dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
    async def delete(self, ids: list[str]):
        """
        Delete vectors for the provided custom IDs.

        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
        to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
        )

    async def delete_entity(self, entity_name: str) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        entity_id = compute_mdhash_id(entity_name, prefix="ent-")
        logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
        await self.delete([entity_id])

    async def delete_entity_relation(self, entity_name: str) -> None:
        """
        Delete relations for a given entity by scanning metadata.
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        logger.debug(f"Searching relations for entity {entity_name}")
        relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
            results.append({**metadata, "id": metadata.get("__id__")})

        return results

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will:
        1. Remove the vector database storage file if it exists
        2. Reinitialize the vector database client
        3. Update flags to notify other processes
        4. Changes are persisted to disk immediately

        This method will remove all vectors from the Faiss index and delete the storage files.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                # Reset the index
                self._index = faiss.IndexFlatIP(self._dim)
                self._id_to_meta = {}

                # Remove storage files if they exist
                if os.path.exists(self._faiss_index_file):
                    os.remove(self._faiss_index_file)
                if os.path.exists(self._meta_file):
                    os.remove(self._meta_file)

                self._id_to_meta = {}
                self._load_faiss_index()

                # Notify other processes
                await set_all_update_flags(self.namespace)
                self.storage_updated.value = False

            logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
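The drop() above resets state by constructing a fresh faiss.IndexFlatIP, which starts with zero vectors; a standalone sketch of that reset (dimension and data are illustrative):

    import faiss
    import numpy as np

    dim = 4
    index = faiss.IndexFlatIP(dim)                        # flat inner-product index
    index.add(np.random.rand(10, dim).astype("float32"))  # faiss expects float32
    print(index.ntotal)                                   # -> 10

    index = faiss.IndexFlatIP(dim)  # the "drop": swap in an empty index
    print(index.ntotal)             # -> 0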
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
if not pm.is_installed("gremlinpython"):
    pm.install("gremlinpython")

from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
from gremlin_python.driver import client, serializer  # type: ignore
from gremlin_python.driver.aiohttp.transport import AiohttpTransport  # type: ignore
from gremlin_python.driver.protocol import GremlinServerError  # type: ignore


@final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
        except Exception as e:
            logger.error(f"Error during edge deletion: {str(e)}")
            raise

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all nodes and relationships in the graph.

        This function deletes all nodes with the specified graph name property,
        which automatically removes all associated edges.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            query = f"""g
                .V().has('graph', {self.graph_name})
                .drop()
                """
            await self._query(query)
            logger.info(f"Successfully dropped all data from graph {self.graph_name}")
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph {self.graph_name}: {e}")
            return {"status": "error", "message": str(e)}
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
        await clear_all_update_flags(self.namespace)

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """
        Importance notes for in-memory storage:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Update flags to notify other processes that data persistence is needed
        """
        if not data:
            return
        logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
        async with self._storage_lock:
            return self._data.get(id)

    async def delete(self, doc_ids: list[str]):
        async with self._storage_lock:
            for doc_id in doc_ids:
                self._data.pop(doc_id, None)
            await set_all_update_flags(self.namespace)
        await self.index_done_callback()
    async def delete(self, doc_ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

    async def drop(self) -> None:
        """Drop the storage"""
        Importance notes for in-memory storage:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Update flags to notify other processes that data persistence is needed

        Args:
            doc_ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        async with self._storage_lock:
            self._data.clear()
            await set_all_update_flags(self.namespace)
        await self.index_done_callback()
            any_deleted = False
            for doc_id in doc_ids:
                result = self._data.pop(doc_id, None)
                if result is not None:
                    any_deleted = True

            if any_deleted:
                await set_all_update_flags(self.namespace)

    async def drop(self) -> dict[str, str]:
        """Drop all document status data from storage and clean up resources

        This method will:
        1. Clear all document status data from memory
        2. Update flags to notify other processes
        3. Trigger index_done_callback to save the empty state

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                self._data.clear()
                await set_all_update_flags(self.namespace)

            await self.index_done_callback()
            logger.info(f"Process {os.getpid()} drop {self.namespace}")
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
        return set(keys) - set(self._data.keys())

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """
        Importance notes for in-memory storage:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Update flags to notify other processes that data persistence is needed
        """
        if not data:
            return
        logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
            await set_all_update_flags(self.namespace)

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Importance notes for in-memory storage:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Update flags to notify other processes that data persistence is needed

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        async with self._storage_lock:
            any_deleted = False
            for doc_id in ids:
                self._data.pop(doc_id, None)
            await set_all_update_flags(self.namespace)
        await self.index_done_callback()
                result = self._data.pop(doc_id, None)
                if result is not None:
                    any_deleted = True

            if any_deleted:
                await set_all_update_flags(self.namespace)

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Importance notes for in-memory storage:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Update flags to notify other processes that data persistence is needed

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            True: if the cache drop succeeded
            False: if the cache drop failed
        """
        if not modes:
            return False

        try:
            await self.delete(modes)
            return True
        except Exception:
            return False

    async def drop(self) -> dict[str, str]:
        """Drop all data from storage and clean up resources
        This action persists the change to disk immediately.

        This method will:
        1. Clear all data from memory
        2. Update flags to notify other processes
        3. Trigger index_done_callback to save the empty state

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                self._data.clear()
                await set_all_update_flags(self.namespace)

            await self.index_done_callback()
            logger.info(f"Process {os.getpid()} drop {self.namespace}")
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
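drop_cache_by_modes() simply delegates to delete(modes), which appears to rely on the JSON KV cache keying its top-level records by mode name, so deleting the mode keys clears that mode's cache. A hedged usage sketch (assumes kv is an initialized JsonKVStorage bound to the LLM response cache; mode names are examples):

    async def clear_query_cache(kv) -> None:
        # kv: an initialized JsonKVStorage for the LLM response cache (assumption)
        ok = await kv.drop_cache_by_modes(["local", "global"])
        if not ok:
            status = await kv.drop()  # fall back to clearing the whole namespace
            print(status["message"])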
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
    pm.install("pymilvus")

import configparser
from pymilvus import MilvusClient
from pymilvus import MilvusClient  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will delete all data from the Milvus collection.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Drop the collection and recreate it
            if self._client.has_collection(self.namespace):
                self._client.drop_collection(self.namespace)

            # Recreate the collection
            MilvusVectorDBStorage.create_collection_if_not_exist(
                self._client,
                self.namespace,
                dimension=self.embedding_func.embedding_dim,
            )

            logger.info(
                f"Process {os.getpid()} drop Milvus collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"):
    pm.install("motor")

from motor.motor_asyncio import (
from motor.motor_asyncio import (  # type: ignore
    AsyncIOMotorClient,
    AsyncIOMotorDatabase,
    AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from pymongo.operations import SearchIndexModel  # type: ignore
from pymongo.errors import PyMongoError  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
        # Mongo handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete documents with specified IDs

        Args:
            ids: List of document IDs to be deleted
        """
        if not ids:
            return

        try:
            result = await self._data.delete_many({"_id": {"$in": ids}})
            logger.info(
                f"Deleted {result.deleted_count} documents from {self.namespace}"
            )
        except PyMongoError as e:
            logger.error(f"Error deleting documents from {self.namespace}: {e}")

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            bool: True if successful, False otherwise
        """
        if not modes:
            return False

        try:
            # Build regex pattern to match documents with the specified modes
            pattern = f"^({'|'.join(modes)})_"
            result = await self._data.delete_many({"_id": {"$regex": pattern}})
            logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
            return True
        except Exception as e:
            logger.error(f"Error deleting cache by modes {modes}: {e}")
            return False

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all documents in the collection.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            result = await self._data.delete_many({})
            deleted_count = result.deleted_count

            logger.info(
                f"Dropped {deleted_count} documents from doc status {self._collection_name}"
            )
            return {
                "status": "success",
                "message": f"{deleted_count} documents dropped",
            }
        except PyMongoError as e:
            logger.error(f"Error dropping doc status {self._collection_name}: {e}")
            return {"status": "error", "message": str(e)}


@final
@dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
        # Mongo handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all documents in the collection.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            result = await self._data.delete_many({})
            deleted_count = result.deleted_count

            logger.info(
                f"Dropped {deleted_count} documents from doc status {self._collection_name}"
            )
            return {
                "status": "success",
                "message": f"{deleted_count} documents dropped",
            }
        except PyMongoError as e:
            logger.error(f"Error dropping doc status {self._collection_name}: {e}")
            return {"status": "error", "message": str(e)}


@final
@dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):

        logger.debug(f"Successfully deleted edges: {edges}")

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all documents in the collection.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            result = await self.collection.delete_many({})
            deleted_count = result.deleted_count

            logger.info(
                f"Dropped {deleted_count} documents from graph {self._collection_name}"
            )
            return {
                "status": "success",
                "message": f"{deleted_count} documents dropped",
            }
        except PyMongoError as e:
            logger.error(f"Error dropping graph {self._collection_name}: {e}")
            return {"status": "error", "message": str(e)}


@final
@dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all documents in the collection and recreating vector index.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        try:
            # Delete all documents
            result = await self._data.delete_many({})
            deleted_count = result.deleted_count

            # Recreate vector index
            await self.create_vector_index_if_not_exists()

            logger.info(
                f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
            )
            return {
                "status": "success",
                "message": f"{deleted_count} documents dropped and vector index recreated",
            }
        except PyMongoError as e:
            logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
            return {"status": "error", "message": str(e)}


async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
    collection_names = await db.list_collection_names()
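The cache-mode pattern built in drop_cache_by_modes matches IDs of the form <mode>_<suffix>; a quick standalone check of the regex it constructs (IDs are made up):

    import re

    modes = ["local", "global"]
    pattern = f"^({'|'.join(modes)})_"   # -> "^(local|global)_"

    ids = ["local_abc123", "global_def456", "hybrid_xyz789"]
    print([i for i in ids if re.match(pattern, i)])
    # -> ['local_abc123', 'global_def456']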
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
        return self._client

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """

        logger.info(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
    async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs

        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption

        Args:
            ids: List of vector IDs to be deleted
        """
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """

        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """

        try:
            client = await self._get_client()
            storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):

        client = await self._get_client()
        return client.get(ids)

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will:
        1. Remove the vector database storage file if it exists
        2. Reinitialize the vector database client
        3. Update flags to notify other processes
        4. Changes are persisted to disk immediately

        This method is intended for use in scenarios where all data needs to be removed.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                # delete _client_file_name
                if os.path.exists(self._client_file_name):
                    os.remove(self._client_file_name)

                self._client = NanoVectorDB(
                    self.embedding_func.embedding_dim,
                    storage_file=self._client_file_name,
                )

                # Notify other processes that data has been updated
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-reloading
                self.storage_updated.value = False

            logger.info(
                f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
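delete_entity() can recompute a vector's ID from the entity name alone because compute_mdhash_id derives it deterministically; a sketch mirroring the md5-plus-prefix convention in lightrag.utils (treat the exact scheme as an assumption of this note):

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # deterministic ID: prefix + md5 hex digest of the content
        return prefix + md5(content.encode()).hexdigest()

    print(compute_mdhash_id("Alan Turing", prefix="ent-"))
    # the same name always yields the same "ent-..." ID, so no lookup is needed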
@@ -1,9 +1,8 @@
import asyncio
import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final, Optional
from typing import Any, final
import numpy as np
import configparser

@@ -29,7 +28,6 @@ from neo4j import (  # type: ignore
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
    GraphDatabase,
)

config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
            embedding_func=embedding_func,
        )
        self._driver = None
        self._driver_lock = asyncio.Lock()

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def initialize(self):
        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
        USERNAME = os.environ.get(
            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
            ),
        )
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
        )

        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
        )

        # Try to connect to the database
        with GraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        ) as _sync_driver:
            for database in (DATABASE, None):
                self._DATABASE = database
                connected = False
        # Try to connect to the database and create it if it doesn't exist
        for database in (DATABASE, None):
            self._DATABASE = database
            connected = False

                try:
                    with _sync_driver.session(database=database) as session:
                        try:
                            session.run("MATCH (n) RETURN n LIMIT 0")
                            logger.info(f"Connected to {database} at {URI}")
                            connected = True
                        except neo4jExceptions.ServiceUnavailable as e:
                            logger.error(
                                f"{database} at {URI} is not available".capitalize()
                            )
                            raise e
                except neo4jExceptions.AuthError as e:
                    logger.error(f"Authentication failed for {database} at {URI}")
                    raise e
                except neo4jExceptions.ClientError as e:
                    if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                        logger.info(
                            f"{database} at {URI} not found. Try to create specified database.".capitalize()
            try:
                async with self._driver.session(database=database) as session:
                    try:
                        result = await session.run("MATCH (n) RETURN n LIMIT 0")
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"Connected to {database} at {URI}")
                        connected = True
                    except neo4jExceptions.ServiceUnavailable as e:
                        logger.error(
                            f"{database} at {URI} is not available".capitalize()
                        )
                        try:
                            with _sync_driver.session() as session:
                                session.run(
                                    f"CREATE DATABASE `{database}` IF NOT EXISTS"
                        raise e
            except neo4jExceptions.AuthError as e:
                logger.error(f"Authentication failed for {database} at {URI}")
                raise e
            except neo4jExceptions.ClientError as e:
                if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                    logger.info(
                        f"{database} at {URI} not found. Try to create specified database.".capitalize()
                    )
                    try:
                        async with self._driver.session() as session:
                            result = await session.run(
                                f"CREATE DATABASE `{database}` IF NOT EXISTS"
                            )
                            await result.consume()  # Ensure result is consumed
                            logger.info(f"{database} at {URI} created".capitalize())
                            connected = True
                    except (
                        neo4jExceptions.ClientError,
                        neo4jExceptions.DatabaseError,
                    ) as e:
                        if (
                            e.code
                            == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                        ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
                            if database is not None:
                                logger.warning(
                                    "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                )
                        logger.info(f"{database} at {URI} created".capitalize())
                        connected = True
                    except (
                        neo4jExceptions.ClientError,
                        neo4jExceptions.DatabaseError,
                    ) as e:
                        if (
                            e.code
                            == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                        ) or (
                            e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                        ):
                            if database is not None:
                                logger.warning(
                                    "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                )
                        if database is None:
                            logger.error(f"Failed to create {database} at {URI}")
                            raise e
                    if database is None:
                        logger.error(f"Failed to create {database} at {URI}")
                        raise e

                if connected:
                    break
            if connected:
                # Create index for base nodes on entity_id if it doesn't exist
                try:
                    async with self._driver.session(database=database) as session:
                        # Check if index exists first
                        check_query = """
                        CALL db.indexes() YIELD name, labelsOrTypes, properties
                        WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
                        RETURN count(*) > 0 AS exists
                        """
                        try:
                            check_result = await session.run(check_query)
                            record = await check_result.single()
                            await check_result.consume()

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
                            index_exists = record and record.get("exists", False)

    async def close(self):
                            if not index_exists:
                                # Create index only if it doesn't exist
                                result = await session.run(
                                    "CREATE INDEX FOR (n:base) ON (n.entity_id)"
                                )
                                await result.consume()
                                logger.info(
                                    f"Created index for base nodes on entity_id in {database}"
                                )
                        except Exception:
                            # Fallback if db.indexes() is not supported in this Neo4j version
                            result = await session.run(
                                "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
                            )
                            await result.consume()
                except Exception as e:
                    logger.warning(f"Failed to create index: {str(e)}")
                break

    async def finalize(self):
        """Close the Neo4j driver and release all resources"""
        if self._driver:
            await self._driver.close()
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):

    async def __aexit__(self, exc_type, exc, tb):
        """Ensure driver is closed when context manager exits"""
        await self.close()
        await self.finalize()

    async def index_done_callback(self) -> None:
        # Neo4j handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
            raise

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its label identifier.
        """Get node by its label identifier, return only node properties

        Args:
            node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
                    logger.debug(
                        f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
                    )
                    # Return default edge properties when no edge found
                    return {
                        "weight": 0.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                    }
                    # Return None when no edge found
                    return None
            finally:
                await result.consume()  # Ensure result is fully consumed

@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
        """
        properties = node_data
        entity_type = properties["entity_type"]
        entity_id = properties["entity_id"]
        if "entity_id" not in properties:
            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
        async def execute_upsert(tx: AsyncManagedTransaction):
            query = (
                """
            MERGE (n:base {entity_id: $properties.entity_id})
            MERGE (n:base {entity_id: $entity_id})
            SET n += $properties
            SET n:`%s`
            """
                % entity_type
            )
            result = await tx.run(query, properties=properties)
            result = await tx.run(
                query, entity_id=node_id, properties=properties
            )
            logger.debug(
                f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
                f"Upserted node with entity_id '{node_id}' and properties: {properties}"
            )
            await result.consume()  # Ensure result is fully consumed
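The upsert change above stops digging entity_id out of $properties inside Cypher and instead passes it as its own parameter, the usual pattern with the official async driver. A minimal hedged sketch (URI and credentials are placeholders, not values from the commit):

    from neo4j import AsyncGraphDatabase

    async def upsert_node_demo(node_id: str, properties: dict) -> None:
        driver = AsyncGraphDatabase.driver(
            "neo4j://localhost:7687", auth=("neo4j", "secret")
        )
        query = """
        MERGE (n:base {entity_id: $entity_id})
        SET n += $properties
        """
        async with driver.session() as session:
            result = await session.run(query, entity_id=node_id, properties=properties)
            await result.consume()  # drain the cursor so the transaction completes cleanly
        await driver.close()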
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node
|
||||
max_depth: Maximum depth of the subgraph
|
||||
min_degree: Minimum degree of nodes to include. Defaults to 0
|
||||
inclusive: Do an inclusive search if true
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Complete connected subgraph for specified node
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
) as session:
|
||||
try:
|
||||
if node_label == "*":
|
||||
# First check total node count to determine if graph is truncated
|
||||
count_query = "MATCH (n) RETURN count(n) as total"
|
||||
count_result = None
|
||||
try:
|
||||
count_result = await session.run(count_query)
|
||||
count_record = await count_result.single()
|
||||
|
||||
if count_record and count_record["total"] > max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
||||
)
|
||||
finally:
|
||||
if count_result:
|
||||
await count_result.consume()
|
||||
|
||||
# Run main query to get nodes with highest degree
|
||||
main_query = """
|
||||
MATCH (n)
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
WITH n, COALESCE(count(r), 0) AS degree
|
||||
WHERE degree >= $min_degree
|
||||
ORDER BY degree DESC
|
||||
LIMIT $max_nodes
|
||||
WITH collect({node: n}) AS filtered_nodes
|
||||
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
"""
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
{"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
|
||||
)
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
{"max_nodes": max_nodes},
|
||||
)
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
|
||||
else:
|
||||
# Main query uses partial matching
|
||||
main_query = """
|
||||
# return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||
# First try without limit to check if we need to truncate
|
||||
full_query = """
|
||||
MATCH (start)
|
||||
WHERE
|
||||
CASE
|
||||
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
|
||||
ELSE start.entity_id = $entity_id
|
||||
END
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {
|
||||
relationshipFilter: '',
|
||||
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
bfs: true
|
||||
})
|
||||
YIELD nodes, relationships
|
||||
WITH start, nodes, relationships
|
||||
WITH nodes, relationships, size(nodes) AS total_nodes
|
||||
UNWIND nodes AS node
|
||||
OPTIONAL MATCH (node)-[r]-()
|
||||
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
|
||||
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
|
||||
ORDER BY
|
||||
CASE
|
||||
WHEN node = start THEN 3
|
||||
WHEN EXISTS((start)--(node)) THEN 2
|
||||
ELSE 1
|
||||
END DESC,
|
||||
degree DESC
|
||||
LIMIT $max_nodes
|
||||
WITH collect({node: node}) AS filtered_nodes
|
||||
UNWIND filtered_nodes AS node_info
|
||||
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
||||
OPTIONAL MATCH (a)-[r]-(b)
|
||||
WHERE a IN kept_nodes AND b IN kept_nodes
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
||||
RETURN node_info, relationships, total_nodes
|
||||
"""
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
{
|
||||
"max_nodes": MAX_GRAPH_NODES,
|
||||
"entity_id": node_label,
|
||||
"inclusive": inclusive,
|
||||
"max_depth": max_depth,
|
||||
"min_degree": min_degree,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
record = await result_set.single()
|
||||
|
||||
if record:
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
for node_info in record["node_info"]:
|
||||
node = node_info["node"]
|
||||
node_id = node.id
|
||||
if node_id not in seen_nodes:
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[node.get("entity_id")],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Handle relationships (including direction information)
|
||||
for rel in record["relationships"]:
|
||||
edge_id = rel.id
|
||||
if edge_id not in seen_edges:
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{start.id}",
|
||||
target=f"{end.id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
||||
# Try to get full result
|
||||
full_result = None
|
||||
try:
|
||||
full_result = await session.run(
|
||||
full_query,
|
||||
{
|
||||
"entity_id": node_label,
|
||||
"max_depth": max_depth,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await result_set.consume() # Ensure result set is consumed
|
||||
full_record = await full_result.single()
|
||||
|
||||
# If no record found, return empty KnowledgeGraph
|
||||
if not full_record:
|
||||
logger.debug(f"No nodes found for entity_id: {node_label}")
|
||||
return result
|
||||
|
||||
# If record found, check node count
|
||||
total_nodes = full_record["total_nodes"]
|
||||
|
||||
if total_nodes <= max_nodes:
|
||||
# If node count is within limit, use full result directly
|
||||
logger.debug(
|
||||
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
||||
)
|
||||
record = full_record
|
||||
else:
|
||||
# If node count exceeds limit, set truncated flag and run limited query
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
||||
)
|
||||
|
||||
# Run limited query
|
||||
limited_query = """
|
||||
MATCH (start)
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {
|
||||
relationshipFilter: '',
|
||||
minLevel: 0,
|
||||
maxLevel: $max_depth,
|
||||
limit: $max_nodes,
|
||||
bfs: true
|
||||
})
|
||||
YIELD nodes, relationships
|
||||
UNWIND nodes AS node
|
||||
WITH collect({node: node}) AS node_info, relationships
|
||||
RETURN node_info, relationships
|
||||
"""
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
limited_query,
|
||||
{
|
||||
"entity_id": node_label,
|
||||
"max_depth": max_depth,
|
||||
"max_nodes": max_nodes,
|
||||
},
|
||||
)
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
finally:
|
||||
if full_result:
|
||||
await full_result.consume()
|
||||
|
||||
if record:
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
for node_info in record["node_info"]:
|
||||
node = node_info["node"]
|
||||
node_id = node.id
|
||||
if node_id not in seen_nodes:
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[node.get("entity_id")],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Handle relationships (including direction information)
|
||||
for rel in record["relationships"]:
|
||||
edge_id = rel.id
|
||||
if edge_id not in seen_edges:
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{start.id}",
|
||||
target=f"{end.id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
logger.warning(f"APOC plugin error: {str(e)}")
|
||||
@@ -767,110 +837,28 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.warning(
|
||||
"Neo4j: falling back to basic Cypher recursive search..."
|
||||
)
|
||||
if inclusive:
|
||||
logger.warning(
|
||||
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
||||
)
|
||||
return await self._robust_fallback(
|
||||
node_label, max_depth, min_degree
|
||||
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||
else:
|
||||
logger.warning(
|
||||
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _robust_fallback(
|
||||
self, node_label: str, max_depth: int, min_degree: int = 0
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Fallback implementation when APOC plugin is not available or incompatible.
|
||||
This method implements the same functionality as get_knowledge_graph but uses
|
||||
only basic Cypher queries and recursive traversal instead of APOC procedures.
|
||||
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
|
||||
"""
|
||||
from collections import deque
|
||||
|
||||
result = KnowledgeGraph()
|
||||
visited_nodes = set()
|
||||
visited_edges = set()
|
||||
|
||||
async def traverse(
|
||||
node: KnowledgeGraphNode,
|
||||
edge: Optional[KnowledgeGraphEdge],
|
||||
current_depth: int,
|
||||
):
|
||||
# Check traversal limits
|
||||
if current_depth > max_depth:
|
||||
logger.debug(f"Reached max depth: {max_depth}")
|
||||
return
|
||||
if len(visited_nodes) >= MAX_GRAPH_NODES:
|
||||
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
|
||||
return
|
||||
|
||||
# Check if node already visited
|
||||
if node.id in visited_nodes:
|
||||
return
|
||||
|
||||
# Get all edges and target nodes
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
"""
|
||||
results = await session.run(query, entity_id=node.id)
|
||||
|
||||
# Get all records and release database connection
|
||||
records = await results.fetch(
|
||||
1000
|
||||
) # Max neighbour nodes we can handled
|
||||
await results.consume() # Ensure results are consumed
|
||||
|
||||
# Nodes not connected to start node need to check degree
|
||||
if current_depth > 1 and len(records) < min_degree:
|
||||
return
|
||||
|
||||
# Add current node to result
|
||||
result.nodes.append(node)
|
||||
visited_nodes.add(node.id)
|
||||
|
||||
# Add edge to result if it exists and not already added
|
||||
if edge and edge.id not in visited_edges:
|
||||
result.edges.append(edge)
|
||||
visited_edges.add(edge.id)
|
||||
|
||||
# Prepare nodes and edges for recursive processing
|
||||
nodes_to_process = []
|
||||
for record in records:
|
||||
rel = record["r"]
|
||||
edge_id = str(record["edge_id"])
|
||||
if edge_id not in visited_edges:
|
||||
b_node = record["b"]
|
||||
target_id = b_node.get("entity_id")
|
||||
|
||||
if target_id: # Only process if target node has entity_id
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=f"{target_id}",
|
||||
labels=list(f"{target_id}"),
|
||||
properties=dict(b_node.properties),
|
||||
)
|
||||
|
||||
# Create KnowledgeGraphEdge
|
||||
target_edge = KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{node.id}",
|
||||
target=f"{target_id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
|
||||
nodes_to_process.append((target_node, target_edge))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping edge {edge_id} due to missing labels on target node"
|
||||
)
|
||||
|
||||
# Process nodes after releasing database connection
|
||||
for target_node, target_edge in nodes_to_process:
|
||||
await traverse(target_node, target_edge, current_depth + 1)
|
||||
visited_edge_pairs = set() # 用于跟踪已处理的边对(排序后的source_id, target_id)
|
||||
|
||||
# Get the starting node's data
|
||||
async with self._driver.session(
|
||||
@@ -889,15 +877,129 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
                # Create initial KnowledgeGraphNode
                start_node = KnowledgeGraphNode(
                    id=f"{node_record['n'].get('entity_id')}",
                    labels=list(f"{node_record['n'].get('entity_id')}"),
                    properties=dict(node_record["n"].properties),
                    labels=[node_record["n"].get("entity_id")],
                    properties=dict(node_record["n"]._properties),
                )
            finally:
                await node_result.consume()  # Ensure results are consumed

            # Start traversal with the initial node
            await traverse(start_node, None, 0)
            # Initialize queue for BFS with (node, edge, depth) tuples
            # edge is None for the starting node
            queue = deque([(start_node, None, 0)])

            # True BFS implementation using a queue
            while queue and len(visited_nodes) < max_nodes:
                # Dequeue the next node to process
                current_node, current_edge, current_depth = queue.popleft()

                # Skip if already visited or exceeds max depth
                if current_node.id in visited_nodes:
                    continue

                if current_depth > max_depth:
                    logger.debug(
                        f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
                    )
                    continue

                # Add current node to result
                result.nodes.append(current_node)
                visited_nodes.add(current_node.id)

                # Add edge to result if it exists and not already added
                if current_edge and current_edge.id not in visited_edges:
                    result.edges.append(current_edge)
                    visited_edges.add(current_edge.id)

                # Stop if we've reached the node limit
                if len(visited_nodes) >= max_nodes:
                    result.is_truncated = True
                    logger.info(
                        f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
                    )
                    break

                # Get all edges and target nodes for the current node (even at max_depth)
                async with self._driver.session(
                    database=self._DATABASE, default_access_mode="READ"
                ) as session:
                    query = """
                    MATCH (a:base {entity_id: $entity_id})-[r]-(b)
                    WITH r, b, id(r) as edge_id, id(b) as target_id
                    RETURN r, b, edge_id, target_id
                    """
                    results = await session.run(query, entity_id=current_node.id)

                    # Get all records and release database connection
                    records = await results.fetch(1000)  # Max neighbor nodes we can handle
                    await results.consume()  # Ensure results are consumed

                    # Process all neighbors - capture all edges but only queue unvisited nodes
                    for record in records:
                        rel = record["r"]
                        edge_id = str(record["edge_id"])

                        if edge_id not in visited_edges:
                            b_node = record["b"]
                            target_id = b_node.get("entity_id")

                            if target_id:  # Only process if target node has entity_id
                                # Create KnowledgeGraphNode for target
                                target_node = KnowledgeGraphNode(
                                    id=f"{target_id}",
                                    labels=[target_id],
                                    properties=dict(b_node._properties),
                                )

                                # Create KnowledgeGraphEdge
                                target_edge = KnowledgeGraphEdge(
                                    id=f"{edge_id}",
                                    type=rel.type,
                                    source=f"{current_node.id}",
                                    target=f"{target_id}",
                                    properties=dict(rel),
                                )

                                # Sort source_id and target_id so (A, B) and (B, A) are treated as the same edge
                                sorted_pair = tuple(sorted([current_node.id, target_id]))

                                # Check whether the same edge already exists (treating edges as undirected)
                                if sorted_pair not in visited_edge_pairs:
                                    # Only add the edge if the target node is already in the result or will be added to it
                                    if target_id in visited_nodes or (
                                        target_id not in visited_nodes
                                        and current_depth < max_depth
                                    ):
                                        result.edges.append(target_edge)
                                        visited_edges.add(edge_id)
                                        visited_edge_pairs.add(sorted_pair)

                                # Only add unvisited nodes to the queue for further expansion
                                if target_id not in visited_nodes:
                                    # Only add to queue if we're not at max depth yet
                                    if current_depth < max_depth:
                                        # Add node to queue with incremented depth
                                        # Edge is already added to result, so we pass None as edge
                                        queue.append((target_node, None, current_depth + 1))
                                    else:
                                        # At max depth, we've already added the edge but we don't add the node
                                        # This prevents adding nodes beyond max_depth to the result
                                        logger.debug(
                                            f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                                        )
                                else:
                                    # If target node already exists in result, we don't need to add it again
                                    logger.debug(
                                        f"Node {target_id} already visited, edge added but node not queued"
                                    )
                            else:
                                logger.warning(
                                    f"Skipping edge {edge_id} due to missing entity_id on target node"
                                )

            logger.info(
                f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
            )
            return result
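
As an aside, the sorted-pair deduplication above is self-contained and easy to verify in isolation. A minimal sketch; add_edge_once is a hypothetical helper, not part of this commit:

visited_edge_pairs: set[tuple[str, ...]] = set()

def add_edge_once(source_id: str, target_id: str) -> bool:
    # Sort the endpoints so (A, B) and (B, A) map to the same undirected key
    pair = tuple(sorted((source_id, target_id)))
    if pair in visited_edge_pairs:
        return False
    visited_edge_pairs.add(pair)
    return True

assert add_edge_once("A", "B") is True
assert add_edge_once("B", "A") is False  # treated as the same undirected edge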

    async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):

        # Method 2: Query compatible with older versions
        query = """
        MATCH (n)
        MATCH (n:base)
        WHERE n.entity_id IS NOT NULL
        RETURN DISTINCT n.entity_id AS label
        ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
        raise NotImplementedError

    async def drop(self) -> dict[str, str]:
        """Drop all data from storage and clean up resources

        This method will delete all nodes and relationships in the Neo4j database.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                # Delete all nodes and relationships
                query = "MATCH (n) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()  # Ensure result is fully consumed

                logger.info(
                    f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
            return {"status": "error", "message": str(e)}
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
        )
        nx.write_graphml(graph, file_name)

    # TODO: deprecated, remove later
    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        graph.add_edge(source_node_id, target_node_id, **edge_data)
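
A hypothetical usage sketch of the two upserts; the node ids and attribute values are illustrative, and persistence only happens at the next index_done_callback:

async def add_demo_pair(storage) -> None:
    # storage is assumed to be an initialized NetworkXStorage instance
    await storage.upsert_node("node-1", {"entity_id": "node-1", "description": "demo"})
    await storage.upsert_node("node-2", {"entity_id": "node-2", "description": "demo"})
    await storage.upsert_edge("node-1", "node-2", {"weight": "1.0"})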

    async def delete_node(self, node_id: str) -> None:
        """
        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        if graph.has_node(node_id):
            graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    # TODO: NOT USED
    async def embed_nodes(
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption

        Args:
            nodes: List of node IDs to be deleted
        """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Importance notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should update the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
        1. min_degree does not affect nodes directly connected to the matching nodes
        2. Label matching nodes take precedence
        3. Followed by nodes directly connected to the matching nodes
        4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            min_degree: Minimum degree of nodes to include. Defaults to 0
            inclusive: Do an inclusive search if true
            node_label: Label of the starting node, "*" means all nodes
            max_depth: Maximum depth of the subgraph. Defaults to 3
            max_nodes: Maximum nodes to return by BFS. Defaults to 1000

        Returns:
            KnowledgeGraph object containing nodes and edges
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to max_nodes limit
        """
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        graph = await self._get_graph()

        # Initialize sets for start nodes and direct connected nodes
        start_nodes = set()
        direct_connected_nodes = set()
        result = KnowledgeGraph()

        # Handle special case for "*" label
        if node_label == "*":
            # For "*", return the entire graph including all nodes and edges
            subgraph = (
                graph.copy()
            )  # Create a copy to avoid modifying the original graph
            # Get degrees of all nodes
            degrees = dict(graph.degree())
            # Sort nodes by degree in descending order and take top max_nodes
            sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)

            # Check if graph is truncated
            if len(sorted_nodes) > max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
                )

            limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
            # Create subgraph with the highest degree nodes
            subgraph = graph.subgraph(limited_nodes)
        else:
            # Find nodes with matching node id based on search_mode
            nodes_to_explore = []
            for n, attr in graph.nodes(data=True):
                node_str = str(n)
                if not inclusive:
                    if node_label == node_str:  # Use exact matching
                        nodes_to_explore.append(n)
                else:  # inclusive mode
                    if node_label in node_str:  # Use partial matching
                        nodes_to_explore.append(n)
            # Check if node exists
            if node_label not in graph:
                logger.warning(f"Node {node_label} not found in the graph")
                return KnowledgeGraph()  # Return empty graph

            if not nodes_to_explore:
                logger.warning(f"No nodes found with label {node_label}")
                return result
            # Use BFS to get nodes
            bfs_nodes = []
            visited = set()
            queue = [(node_label, 0)]  # (node, depth) tuple

            # Get subgraph using ego_graph from all matching nodes
            combined_subgraph = nx.Graph()
            for start_node in nodes_to_explore:
                node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
                combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
            # Breadth-first search
            while queue and len(bfs_nodes) < max_nodes:
                current, depth = queue.pop(0)
                if current not in visited:
                    visited.add(current)
                    bfs_nodes.append(current)

                    # Get start nodes and direct connected nodes
                    if nodes_to_explore:
                        start_nodes = set(nodes_to_explore)
                        # Get nodes directly connected to all start nodes
                        for start_node in start_nodes:
                            direct_connected_nodes.update(
                                combined_subgraph.neighbors(start_node)
                            )
                    # Only explore neighbors if we haven't reached max_depth
                    if depth < max_depth:
                        # Add neighbor nodes to queue with incremented depth
                        neighbors = list(graph.neighbors(current))
                        queue.extend(
                            [(n, depth + 1) for n in neighbors if n not in visited]
                        )

            # Remove start nodes from directly connected nodes (avoid duplicates)
            direct_connected_nodes -= start_nodes
            # Check if graph is truncated - if we still have nodes in the queue
            # and we've reached max_nodes, then the graph is truncated
            if queue and len(bfs_nodes) >= max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
                )

            subgraph = combined_subgraph

            # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
            if min_degree > 0:
                nodes_to_keep = [
                    node
                    for node, degree in subgraph.degree()
                    if node in start_nodes
                    or node in direct_connected_nodes
                    or degree >= min_degree
                ]
                subgraph = subgraph.subgraph(nodes_to_keep)

            # Check if number of nodes exceeds max_graph_nodes
            if len(subgraph.nodes()) > MAX_GRAPH_NODES:
                origin_nodes = len(subgraph.nodes())
                node_degrees = dict(subgraph.degree())

                def priority_key(node_item):
                    node, degree = node_item
                    # Priority order: start(2) > directly connected(1) > other nodes(0)
                    if node in start_nodes:
                        priority = 2
                    elif node in direct_connected_nodes:
                        priority = 1
                    else:
                        priority = 0
                    return (priority, degree)

                # Sort by priority and degree and select top MAX_GRAPH_NODES nodes
                top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
                    :MAX_GRAPH_NODES
                ]
                top_node_ids = [node[0] for node in top_nodes]
                # Create new subgraph, keeping only the highest-priority, highest-degree nodes
                subgraph = subgraph.subgraph(top_node_ids)
                logger.info(
                    f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
                )
            # Create subgraph with BFS discovered nodes
            subgraph = graph.subgraph(bfs_nodes)

        # Add nodes to result
        seen_nodes = set()
        seen_edges = set()
        for node in subgraph.nodes():
            if str(node) in seen_nodes:
                continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
        for edge in subgraph.edges():
            source, target = edge
            # Ensure unique edge_id for undirected graph
            if source > target:
            if str(source) > str(target):
                source, target = target, source
            edge_id = f"{source}-{target}"
            if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
            return False  # Return error

        return True
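
For reference, the degree-based truncation used by the "*" branch of get_knowledge_graph above can be isolated. A minimal sketch assuming networkx; the function name top_k_by_degree is hypothetical, not part of this commit:

import networkx as nx

def top_k_by_degree(graph: nx.Graph, max_nodes: int) -> nx.Graph:
    # Rank nodes by degree, descending, and keep the top max_nodes
    ranked = sorted(dict(graph.degree()).items(), key=lambda x: x[1], reverse=True)
    keep = [node for node, _ in ranked[:max_nodes]]
    return graph.subgraph(keep)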

    async def drop(self) -> dict[str, str]:
        """Drop all graph data from storage and clean up resources

        This method will:
        1. Remove the graph storage file if it exists
        2. Reset the graph to an empty state
        3. Update flags to notify other processes
        4. Changes are persisted to disk immediately

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                # delete _client_file_name
                if os.path.exists(self._graphml_xml_file):
                    os.remove(self._graphml_xml_file)
                self._graph = nx.Graph()
                # Notify other processes that data has been updated
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-reloading
                self.storage_updated.value = False
                logger.info(
                    f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -8,17 +8,15 @@ import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser


config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

import pipmaster as pm

if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

from qdrant_client import QdrantClient, models
from qdrant_client import QdrantClient, models  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
        except Exception as e:
            logger.error(f"Error searching for prefix '{prefix}': {e}")
            return []

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        try:
            # Convert to Qdrant compatible ID
            qdrant_id = compute_mdhash_id_for_qdrant(id)

            # Retrieve the point by ID
            result = self._client.retrieve(
                collection_name=self.namespace,
                ids=[qdrant_id],
                with_payload=True,
            )

            if not result:
                return None

            return result[0].payload
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []

        try:
            # Convert to Qdrant compatible IDs
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]

            # Retrieve the points by IDs
            results = self._client.retrieve(
                collection_name=self.namespace,
                ids=qdrant_ids,
                with_payload=True,
            )

            return [point.payload for point in results]
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []
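
A hypothetical usage sketch; the ids are illustrative, and the Qdrant point-id conversion via compute_mdhash_id_for_qdrant happens inside both methods:

async def fetch_payloads(storage) -> None:
    # storage is assumed to be an initialized QdrantVectorDBStorage instance
    one = await storage.get_by_id("ent-abc123")  # payload dict or None
    many = await storage.get_by_ids(["ent-abc123", "rel-def456"])  # list of payload dicts
    print(one, len(many))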

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will delete all data from the Qdrant collection.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Delete the collection and recreate it
            if self._client.collection_exists(self.namespace):
                self._client.delete_collection(self.namespace)

            # Recreate the collection
            QdrantVectorDBStorage.create_collection_if_not_exist(
                self._client,
                self.namespace,
                vectors_config=models.VectorParams(
                    size=self.embedding_func.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )

            logger.info(
                f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
from redis.asyncio import Redis, ConnectionPool
from redis.exceptions import RedisError, ConnectionError
from lightrag.utils import logger, compute_mdhash_id

from lightrag.base import BaseKVStorage
import json

@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
        except json.JSONEncodeError as e:
            logger.error(f"JSON encode error during upsert: {e}")
            raise

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete entries with specified IDs"""
        if not ids:
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
                    f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
                )

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by name"""
    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Importance notes for Redis storage:
        1. This will immediately delete the specified cache modes from Redis

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            True: if the cache drop succeeded
            False: if the cache drop failed
        """
        if not modes:
            return False

        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            await self.delete(modes)
            return True
        except Exception:
            return False

            async with self._get_redis_connection() as redis:
                result = await redis.delete(f"{self.namespace}:{entity_id}")
    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all keys under the current namespace.

                if result:
                    logger.debug(f"Successfully deleted entity {entity_name}")
                else:
                    logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")
        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        async with self._get_redis_connection() as redis:
            try:
                keys = await redis.keys(f"{self.namespace}:*")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity"""
        try:
            async with self._get_redis_connection() as redis:
                cursor = 0
                relation_keys = []
                pattern = f"{self.namespace}:*"

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern)

                    # Process keys in batches
                    if keys:
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()

                        for key, value in zip(keys, values):
                            if value:
                                try:
                                    data = json.loads(value)
                                    if (
                                        data.get("src_id") == entity_name
                                        or data.get("tgt_id") == entity_name
                                    ):
                                        relation_keys.append(key)
                                except json.JSONDecodeError:
                                    logger.warning(f"Invalid JSON in key {key}")
                                    continue
                        pipe.delete(key)
                    results = await pipe.execute()
                    deleted_count = sum(results)

                    if cursor == 0:
                        break

                # Delete relations in batches
                if relation_keys:
                    # Delete in chunks to avoid too many arguments
                    chunk_size = 1000
                    for i in range(0, len(relation_keys), chunk_size):
                        chunk = relation_keys[i:i + chunk_size]
                        deleted = await redis.delete(*chunk)
                        logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
                    logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
                    return {"status": "success", "message": f"{deleted_count} keys dropped"}
                else:
                    logger.debug(f"No relations found for entity {entity_name}")
                    logger.info(f"No keys found to drop in {self.namespace}")
                    return {"status": "success", "message": "no keys to drop"}

        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")
            except Exception as e:
                logger.error(f"Error dropping keys from {self.namespace}: {e}")
                return {"status": "error", "message": str(e)}

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
    pm.install("sqlalchemy")

from sqlalchemy import create_engine, text
from sqlalchemy import create_engine, text  # type: ignore


class TiDB:
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
        # Ti handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete records with specified IDs from the storage.

        Args:
            ids: List of record IDs to be deleted
        """
        if not ids:
            return

        try:
            table_name = namespace_to_table_name(self.namespace)
            id_field = namespace_to_id(self.namespace)

            if not table_name or not id_field:
                logger.error(f"Unknown namespace for deletion: {self.namespace}")
                return

            ids_list = ",".join([f"'{id}'" for id in ids])
            delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"

            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
            logger.info(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error deleting records from {self.namespace}: {e}")

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            bool: True if successful, False otherwise
        """
        if not modes:
            return False

        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return False

            if table_name != "LIGHTRAG_LLM_CACHE":
                return False

            # Build a MySQL-style IN clause
            modes_list = ", ".join([f"'{mode}'" for mode in modes])
            sql = f"""
            DELETE FROM {table_name}
            WHERE workspace = :workspace
            AND mode IN ({modes_list})
            """

            logger.info(f"Deleting cache by modes: {modes}")
            await self.db.execute(sql, {"workspace": self.db.workspace})
            return True
        except Exception as e:
            logger.error(f"Error deleting cache by modes {modes}: {e}")
            return False
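
Both delete() and drop_cache_by_modes() interpolate quoted values straight into the IN list. A parameterized variant using the same :name binding style seen elsewhere in this file is possible; build_mode_delete is a hypothetical helper, not part of this commit:

def build_mode_delete(modes: list[str], workspace: str) -> tuple[str, dict]:
    # Bind each mode as a named parameter instead of inlining quoted strings
    params = {"workspace": workspace}
    placeholders = []
    for i, mode in enumerate(modes):
        params[f"mode_{i}"] = mode
        placeholders.append(f":mode_{i}")
    sql = (
        "DELETE FROM LIGHTRAG_LLM_CACHE "
        f"WHERE workspace = :workspace AND mode IN ({', '.join(placeholders)})"
    )
    return sql, params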

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

@final
@dataclass
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
            params = {"workspace": self.db.workspace, "status": status}
            return await self.db.query(SQL, params, multirows=True)

    async def delete(self, ids: list[str]) -> None:
        """Delete vectors with specified IDs from the storage.

        Args:
            ids: List of vector IDs to be deleted
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        id_field = namespace_to_id(self.namespace)

        if not table_name or not id_field:
            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
            return

        ids_list = ",".join([f"'{id}'" for id in ids])
        delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"

        try:
            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        raise NotImplementedError
        """Delete an entity by its name from the vector storage.

        Args:
            entity_name: The name of the entity to delete
        """
        try:
            # Construct SQL to delete the entity
            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
                            WHERE workspace = :workspace AND name = :entity_name"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        raise NotImplementedError
        """Delete all relations associated with an entity.

        Args:
            entity_name: The name of the entity whose relations should be deleted
        """
        try:
            # Delete relations where the entity is either the source or target
            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
                            WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted relations for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for entity {entity_name}: {e}")

    async def index_done_callback(self) -> None:
        # Ti handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
        """Search for records with IDs starting with a specific prefix.

@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
        # Ti handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            drop_sql = """
            DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
            DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
            """
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def delete_node(self, node_id: str) -> None:
        """Delete a node and all its related edges

@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
        FROM LIGHTRAG_DOC_CHUNKS
        WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
    """,
    # Drop tables
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}