Merge branch 'main' into fix--postgres-impl

This commit is contained in:
jofoks
2025-03-13 16:16:25 -07:00
24 changed files with 635 additions and 179 deletions

View File

@@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
try:
embedding = await self.embedding_func([query])

View File

@@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""
Search by a textual query; returns top_k results with their metadata + similarity distance.
"""

View File

@@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,

View File

@@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search."""
# Generate the embedding
embedding = await self.embedding_func([query])

View File

@@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query])
embedding = embedding[0]

View File

@@ -553,18 +553,6 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error during upsert: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -666,14 +654,14 @@ class Neo4JStorage(BaseGraphStorage):
main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) AS degree
WITH n, COALESCE(count(r), 0) AS degree
WHERE degree >= $min_degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
MATCH (a)-[r]-(b)
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
@@ -703,7 +691,7 @@ class Neo4JStorage(BaseGraphStorage):
WITH start, nodes, relationships
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, count(r) AS degree, start, nodes, relationships
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
ORDER BY
CASE
@@ -716,7 +704,7 @@ class Neo4JStorage(BaseGraphStorage):
WITH collect({node: node}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
MATCH (a)-[r]-(b)
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
@@ -744,11 +732,7 @@ class Neo4JStorage(BaseGraphStorage):
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[
label
for label in node.labels
if label != "base"
],
labels=[node.get("entity_id")],
properties=dict(node),
)
)
@@ -865,9 +849,7 @@ class Neo4JStorage(BaseGraphStorage):
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[
label for label in b_node.labels if label != "base"
],
labels=list(f"{target_id}"),
properties=dict(b_node.properties),
)
@@ -907,9 +889,7 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}",
labels=[
label for label in node_record["n"].labels if label != "base"
],
labels=list(f"{node_record['n'].get('entity_id')}"),
properties=dict(node_record["n"].properties),
)
finally:

View File

@@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
self.db = None
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
# 转换精度

View File

@@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,

View File

@@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
await ClientManager.release_client(self.db)
self.db = None
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]