cleaned code

This commit is contained in:
Yannick Stephan
2025-02-16 13:55:30 +01:00
parent 882190a515
commit 931c31fa8c
14 changed files with 93 additions and 59 deletions

View File

@@ -92,6 +92,7 @@ class StorageNameSpace:
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results."""
raise NotImplementedError
@@ -165,7 +166,6 @@ class BaseGraphStorage(StorageNameSpace):
"""Get all edges connected to a node."""
raise NotImplementedError
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Upsert a node into the graph."""
raise NotImplementedError
@@ -194,7 +194,9 @@ class BaseGraphStorage(StorageNameSpace):
"""Get a knowledge graph of a node."""
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
raise NotImplementedError

View File

@@ -615,11 +615,15 @@ class AGEStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -183,7 +183,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB query: {str(e)}")
raise
async def index_done_callback(self) -> None:
# ChromaDB handles persistence automatically
pass

View File

@@ -398,5 +398,7 @@ class GremlinStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage):
]
async def filter_keys(self, keys: set[str]) -> set[str]:
return set(data) - set(self._data.keys())
return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data}

View File

@@ -68,9 +68,9 @@ class MongoKVStorage(BaseKVStorage):
return await cursor.to_list()
async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids
return keys - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -570,7 +570,9 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Placeholder for demonstration, raises NotImplementedError.
"""
@@ -600,7 +602,9 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
@@ -927,6 +931,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""
client = MongoClient(uri)

View File

@@ -254,7 +254,6 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
try:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
@@ -436,7 +435,9 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)

View File

@@ -186,7 +186,9 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -226,5 +228,7 @@ class NetworkXStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -318,6 +318,7 @@ class OracleKVStorage(BaseKVStorage):
async def drop(self) -> None:
raise NotImplementedError
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
@@ -368,6 +369,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
@@ -452,7 +454,9 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -597,9 +601,12 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",

View File

@@ -299,6 +299,7 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> None:
raise NotImplementedError
@dataclass
class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
@@ -428,6 +429,7 @@ class PGVectorStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use
@@ -1049,9 +1051,12 @@ class PGGraphStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",

View File

@@ -41,12 +41,12 @@ class RedisKVStorage(BaseKVStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
pipe = self._redis.pipeline()
for key in data:
for key in keys:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
return set(data) - existing_ids
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pipe = self._redis.pipeline()

View File

@@ -127,7 +127,6 @@ class TiDBKVStorage(BaseKVStorage):
return await self.db.query(SQL, multirows=True)
async def filter_keys(self, keys: set[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace),
@@ -211,6 +210,7 @@ class TiDBKVStorage(BaseKVStorage):
async def drop(self) -> None:
raise NotImplementedError
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
@@ -335,7 +335,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
@@ -344,6 +343,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
@@ -420,7 +420,9 @@ class TiDBGraphStorage(BaseGraphStorage):
}
await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -485,9 +487,12 @@ class TiDBGraphStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",