Merge branch 'main' into fix--postgres-impl
This commit is contained in:
@@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
|
@@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||
"""
|
||||
|
@@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
||||
return results
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
|
@@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
return list_data
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Queries the vector database using Atlas Vector Search."""
|
||||
# Generate the embedding
|
||||
embedding = await self.embedding_func([query])
|
||||
|
@@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
|
@@ -553,18 +553,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error during upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ClientError,
|
||||
)
|
||||
),
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@@ -666,14 +654,14 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
main_query = """
|
||||
MATCH (n)
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
WITH n, count(r) AS degree
|
||||
WITH n, COALESCE(count(r), 0) AS degree
|
||||
WHERE degree >= $min_degree
|
||||
ORDER BY degree DESC
|
||||
LIMIT $max_nodes
|
||||
WITH collect({node: n}) AS filtered_nodes
|
||||
UNWIND filtered_nodes AS node_info
|
||||
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
||||
MATCH (a)-[r]-(b)
|
||||
OPTIONAL MATCH (a)-[r]-(b)
|
||||
WHERE a IN kept_nodes AND b IN kept_nodes
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
@@ -703,7 +691,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
WITH start, nodes, relationships
|
||||
UNWIND nodes AS node
|
||||
OPTIONAL MATCH (node)-[r]-()
|
||||
WITH node, count(r) AS degree, start, nodes, relationships
|
||||
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
|
||||
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
|
||||
ORDER BY
|
||||
CASE
|
||||
@@ -716,7 +704,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
WITH collect({node: node}) AS filtered_nodes
|
||||
UNWIND filtered_nodes AS node_info
|
||||
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
||||
MATCH (a)-[r]-(b)
|
||||
OPTIONAL MATCH (a)-[r]-(b)
|
||||
WHERE a IN kept_nodes AND b IN kept_nodes
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
@@ -744,11 +732,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[
|
||||
label
|
||||
for label in node.labels
|
||||
if label != "base"
|
||||
],
|
||||
labels=[node.get("entity_id")],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
@@ -865,9 +849,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=f"{target_id}",
|
||||
labels=[
|
||||
label for label in b_node.labels if label != "base"
|
||||
],
|
||||
labels=list(f"{target_id}"),
|
||||
properties=dict(b_node.properties),
|
||||
)
|
||||
|
||||
@@ -907,9 +889,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Create initial KnowledgeGraphNode
|
||||
start_node = KnowledgeGraphNode(
|
||||
id=f"{node_record['n'].get('entity_id')}",
|
||||
labels=[
|
||||
label for label in node_record["n"].labels if label != "base"
|
||||
],
|
||||
labels=list(f"{node_record['n'].get('entity_id')}"),
|
||||
properties=dict(node_record["n"].properties),
|
||||
)
|
||||
finally:
|
||||
|
@@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
self.db = None
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
# 转换精度
|
||||
|
@@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
return results
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
|
@@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search from tidb vector"""
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
|
Reference in New Issue
Block a user