Merge pull request #1032 from ArindamRoy23/main
Filter by ID during Query for Postgres VDB
This commit is contained in:
@@ -81,6 +81,9 @@ class QueryParam:
|
||||
history_turns: int = 3
|
||||
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
||||
|
||||
ids: list[str] | None = None
|
||||
"""List of document ids to restrict the results to; None (the default) applies no filter."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageNameSpace(ABC):
|
||||
@@ -107,7 +110,9 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
meta_fields: set[str] = field(default_factory=set)
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query the vector storage and retrieve top_k results, optionally restricted to the given ids."""
|
||||
|
||||
@abstractmethod
|
||||
|
@@ -438,6 +438,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
"entity_name": item["entity_name"],
|
||||
"content": item["content"],
|
||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||
"chunk_id": item["source_id"],
|
||||
# TODO: add document_id
|
||||
}
|
||||
return upsert_sql, data
|
||||
|
||||
@@ -450,6 +452,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
"target_id": item["tgt_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||
"chunk_id": item["source_id"],
|
||||
# TODO: add document_id
|
||||
}
|
||||
return upsert_sql, data
|
||||
|
||||
@@ -492,13 +496,20 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
embedding_string = ",".join(map(str, embedding))
|
||||
|
||||
if ids:
|
||||
formatted_ids = ",".join(f"'{id}'" for id in ids)
|
||||
else:
|
||||
formatted_ids = "NULL"
|
||||
|
||||
sql = SQL_TEMPLATES[self.base_namespace].format(
|
||||
embedding_string=embedding_string
|
||||
embedding_string=embedding_string, doc_ids=formatted_ids
|
||||
)
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
@@ -1491,6 +1502,7 @@ TABLES = {
|
||||
content_vector VECTOR,
|
||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP,
|
||||
chunk_id VARCHAR(255) NULL,
|
||||
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
|
||||
)"""
|
||||
},
|
||||
@@ -1504,6 +1516,7 @@ TABLES = {
|
||||
content_vector VECTOR,
|
||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP,
|
||||
chunk_id VARCHAR(255) NULL,
|
||||
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
|
||||
)"""
|
||||
},
|
||||
@@ -1586,8 +1599,9 @@ SQL_TEMPLATES = {
|
||||
content_vector=EXCLUDED.content_vector,
|
||||
update_time = CURRENT_TIMESTAMP
|
||||
""",
|
||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
||||
content_vector, chunk_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET entity_name=EXCLUDED.entity_name,
|
||||
content=EXCLUDED.content,
|
||||
@@ -1595,8 +1609,8 @@ SQL_TEMPLATES = {
|
||||
update_time=CURRENT_TIMESTAMP
|
||||
""",
|
||||
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
|
||||
target_id, content, content_vector)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
target_id, content, content_vector, chunk_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET source_id=EXCLUDED.source_id,
|
||||
target_id=EXCLUDED.target_id,
|
||||
@@ -1604,21 +1618,21 @@ SQL_TEMPLATES = {
|
||||
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
|
||||
""",
|
||||
# SQL for VectorStorage
|
||||
"entities": """SELECT entity_name FROM
|
||||
(SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
""",
|
||||
"relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
||||
(SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
""",
|
||||
"chunks": """SELECT id FROM
|
||||
(SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
""",
|
||||
# "entities": """SELECT entity_name FROM
|
||||
# (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
# FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
||||
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
# """,
|
||||
# "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
||||
# (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
# FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
||||
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
# """,
|
||||
# "chunks": """SELECT id FROM
|
||||
# (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
# FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
||||
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||
# """,
|
||||
# DROP tables
|
||||
"drop_all": """
|
||||
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
||||
@@ -1642,4 +1656,55 @@ SQL_TEMPLATES = {
|
||||
"drop_vdb_relation": """
|
||||
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
||||
""",
|
||||
"relationships": """
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||
)
|
||||
SELECT source_id as src_id, target_id as tgt_id
|
||||
FROM (
|
||||
SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_VDB_RELATION r
|
||||
WHERE r.workspace=$1
|
||||
AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
||||
) filtered
|
||||
WHERE distance>$2
|
||||
ORDER BY distance DESC
|
||||
LIMIT $3
|
||||
""",
|
||||
"entities": """
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||
)
|
||||
SELECT entity_name FROM
|
||||
(
|
||||
SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_VDB_ENTITY
|
||||
where workspace=$1
|
||||
AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
||||
)
|
||||
WHERE distance>$2
|
||||
ORDER BY distance DESC
|
||||
LIMIT $3
|
||||
""",
|
||||
"chunks": """
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||
)
|
||||
SELECT id FROM
|
||||
(
|
||||
SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
where workspace=$1
|
||||
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
||||
)
|
||||
WHERE distance>$2
|
||||
ORDER BY distance DESC
|
||||
LIMIT $3
|
||||
""",
|
||||
}
|
||||
|
@@ -962,7 +962,10 @@ async def mix_kg_vector_query(
|
||||
try:
|
||||
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
|
||||
mix_topk = min(10, query_param.top_k)
|
||||
results = await chunks_vdb.query(augmented_query, top_k=mix_topk)
|
||||
# TODO: add ids to the query
|
||||
results = await chunks_vdb.query(
|
||||
augmented_query, top_k=mix_topk, ids=query_param.ids
|
||||
)
|
||||
if not results:
|
||||
return None
|
||||
|
||||
@@ -1171,7 +1174,11 @@ async def _get_node_data(
|
||||
logger.info(
|
||||
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||
|
||||
results = await entities_vdb.query(
|
||||
query, top_k=query_param.top_k, ids=query_param.ids
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
# get entity information
|
||||
@@ -1424,7 +1431,10 @@ async def _get_edge_data(
|
||||
logger.info(
|
||||
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
|
||||
results = await relationships_vdb.query(
|
||||
keywords, top_k=query_param.top_k, ids=query_param.ids
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
@@ -1673,7 +1683,9 @@ async def naive_query(
|
||||
if cached_response is not None:
|
||||
return cached_response
|
||||
|
||||
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
||||
results = await chunks_vdb.query(
|
||||
query, top_k=query_param.top_k, ids=query_param.ids
|
||||
)
|
||||
if not len(results):
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
|
Reference in New Issue
Block a user