Refactor vector query methods to support optional ID filtering

- Updated BaseVectorStorage query method signature to accept optional IDs
- Modified operate.py to pass query parameter IDs to vector storage queries
- Updated PostgreSQL vector storage SQL templates to filter results by document IDs
- Removed unused parameters and simplified query logic across multiple files
Author: Roy
Date:   2025-03-08 15:43:17 +00:00
Parent: bbe139cfeb
Commit: 528fb11364
4 changed files with 85 additions and 34 deletions
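
In practice, the filter is meant to be supplied on the query parameters and forwarded to every vector lookup. A minimal usage sketch (not part of this commit), assuming an already-initialised LightRAG instance and a QueryParam that carries the new ids field, as the diffs below suggest:

    from lightrag import QueryParam

    async def ask_within_docs(rag, question: str, doc_ids: list[str]) -> str:
        # Restrict retrieval to chunks belonging to these document IDs; the
        # entities/relationships/chunks vector queries all receive the filter.
        param = QueryParam(mode="hybrid", top_k=60, ids=doc_ids)
        return await rag.aquery(question, param=param)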

View File

@@ -108,9 +108,8 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     embedding_func: EmbeddingFunc
     cosine_better_than_threshold: float = field(default=0.2)
     meta_fields: set[str] = field(default_factory=set)
     @abstractmethod
-    async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
+    async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
         """Query the vector storage and retrieve top_k results."""
     @abstractmethod
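
For storage backends other than PostgreSQL, the new signature means ids, when provided, acts as an allow-list applied before ranking. A standalone toy sketch of that contract (not the real implementation; the index layout and the full_doc_id key are illustrative):

    from typing import Any

    async def query(index: list[dict[str, Any]], query_vec: list[float],
                    top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
        # When `ids` is given, only items from those documents are candidates.
        candidates = index if ids is None else [it for it in index if it["full_doc_id"] in ids]

        def score(item: dict[str, Any]) -> float:
            # Dot product stands in for cosine similarity (vectors assumed normalised).
            return sum(a * b for a, b in zip(item["vector"], query_vec))

        return sorted(candidates, key=score, reverse=True)[:top_k]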

View File

@@ -439,6 +439,7 @@ class PGVectorStorage(BaseVectorStorage):
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
             "chunk_id": item["source_id"],
+            #TODO: add document_id
         }
         return upsert_sql, data
@@ -452,6 +453,7 @@ class PGVectorStorage(BaseVectorStorage):
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
             "chunk_id": item["source_id"]
+            #TODO: add document_id
         }
         return upsert_sql, data
@@ -494,13 +496,19 @@ class PGVectorStorage(BaseVectorStorage):
         await self.db.execute(upsert_sql, data)
     #################### query method ###############
-    async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
+    async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
+        if ids:
+            formatted_ids = ",".join(f"'{id}'" for id in ids)
+        else:
+            formatted_ids = "NULL"
         sql = SQL_TEMPLATES[self.base_namespace].format(
-            embedding_string=embedding_string
+            embedding_string=embedding_string,
+            doc_ids=formatted_ids
         )
         params = {
             "workspace": self.db.workspace,
@@ -1389,7 +1397,6 @@ TABLES = {
            content_vector VECTOR,
            create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            update_time TIMESTAMP,
-           document_id VARCHAR(255) NULL,
            chunk_id VARCHAR(255) NULL,
            CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
            )"""
@@ -1404,7 +1411,6 @@ TABLES = {
            content_vector VECTOR,
            create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            update_time TIMESTAMP,
-           document_id VARCHAR(255) NULL,
            chunk_id VARCHAR(255) NULL,
            CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
            )"""
@@ -1507,21 +1513,21 @@ SQL_TEMPLATES = {
         content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
     """,
     # SQL for VectorStorage
-    "entities": """SELECT entity_name FROM
-        (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-        FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
-        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """,
-    "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
-        (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-        FROM LIGHTRAG_VDB_RELATION where workspace=$1)
-        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """,
-    "chunks": """SELECT id FROM
-        (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-        FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
-        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """,
+    # "entities": """SELECT entity_name FROM
+    #     (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #     FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
+    #     WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    # """,
+    # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
+    #     (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #     FROM LIGHTRAG_VDB_RELATION where workspace=$1)
+    #     WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    # """,
+    # "chunks": """SELECT id FROM
+    #     (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #     FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
+    #     WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    # """,
     # DROP tables
     "drop_all": """
        DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
@@ -1545,4 +1551,56 @@ SQL_TEMPLATES = {
     "drop_vdb_relation": """
        DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
     """,
-}
+    "relationships": """
+    WITH relevant_chunks AS (
+        SELECT id as chunk_id
+        FROM LIGHTRAG_DOC_CHUNKS
+        WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+    )
+    SELECT source_id as src_id, target_id as tgt_id
+    FROM (
+        SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
+        FROM LIGHTRAG_VDB_RELATION r
+        WHERE r.workspace=$1
+        AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
+    ) filtered
+    WHERE distance>$2
+    ORDER BY distance DESC
+    LIMIT $3
+    """,
+    "entities":
+    '''
+    WITH relevant_chunks AS (
+        SELECT id as chunk_id
+        FROM LIGHTRAG_DOC_CHUNKS
+        WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+    )
+    SELECT entity_name FROM
+        (
+            SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+            FROM LIGHTRAG_VDB_ENTITY
+            where workspace=$1
+            AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
+        )
+    WHERE distance>$2
+    ORDER BY distance DESC
+    LIMIT $3
+    ''',
+    'chunks': """
+    WITH relevant_chunks AS (
+        SELECT id as chunk_id
+        FROM LIGHTRAG_DOC_CHUNKS
+        WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+    )
+    SELECT id FROM
+        (
+            SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+            FROM LIGHTRAG_DOC_CHUNKS
+            where workspace=$1
+            AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
+        )
+    WHERE distance>$2
+    ORDER BY distance DESC
+    LIMIT $3
+    """
+}
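
Each rendered template still expects three bound parameters: the workspace ($1), the cosine-similarity threshold ($2), and top_k ($3). A hedged execution sketch that runs a rendered "relationships" query directly with asyncpg (not how LightRAG wires its own connection pool; the DSN and values are placeholders):

    import asyncio
    import asyncpg

    async def fetch_related_pairs(dsn: str, sql: str, workspace: str,
                                  threshold: float, top_k: int) -> list[tuple[str, str]]:
        # `sql` is a rendered template such as "relationships" above, with
        # {embedding_string} and {doc_ids} already substituted.
        conn = await asyncpg.connect(dsn)
        try:
            rows = await conn.fetch(sql, workspace, threshold, top_k)
            return [(row["src_id"], row["tgt_id"]) for row in rows]
        finally:
            await conn.close()

    # asyncio.run(fetch_related_pairs("postgresql://user:pass@localhost/lightrag",
    #                                 rendered_sql, "default", 0.2, 60))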

View File

@@ -1243,7 +1243,6 @@ class LightRAG:
                     embedding_func=self.embedding_func,
                 ),
                 system_prompt=system_prompt,
-                ids = param.ids
             )
         elif param.mode == "naive":
             response = await naive_query(

View File

@@ -602,7 +602,6 @@ async def kg_query(
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
-    ids: list[str] | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -650,7 +649,6 @@
         relationships_vdb,
         text_chunks_db,
         query_param,
-        ids
     )
     if query_param.only_need_context:
@@ -1035,7 +1033,6 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
-            ids = ids
         )
     else: # hybrid mode
         ll_data, hl_data = await asyncio.gather(
@@ -1104,7 +1101,9 @@ async def _get_node_data(
     logger.info(
         f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
     )
-    results = await entities_vdb.query(query, top_k=query_param.top_k)
+    results = await entities_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
     if not len(results):
         return "", "", ""
     # get entity information
@@ -1352,16 +1351,12 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    ids: list[str] | None = None,
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
     )
-    if ids:
-        #TODO: add ids to the query
-        results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = ids)
-    else:
-        results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
+    results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = query_param.ids)
     if not len(results):
         return "", "", ""
@@ -1610,7 +1605,7 @@ async def naive_query(
     if cached_response is not None:
         return cached_response
-    results = await chunks_vdb.query(query, top_k=query_param.top_k)
+    results = await chunks_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
     if not len(results):
         return PROMPTS["fail_response"]
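
After this change, operate.py no longer threads an ids argument through kg_query, _build_query_context, and _get_edge_data; each retrieval site reads query_param.ids directly. A sketch of the QueryParam shape that design relies on (assumed for illustration; the real dataclass lives elsewhere in the codebase and carries many more fields):

    from dataclasses import dataclass

    @dataclass
    class QueryParam:
        mode: str = "global"            # e.g. "local", "global", "hybrid", "naive"
        top_k: int = 60
        only_need_context: bool = False
        ids: list[str] | None = None    # optional allow-list of document IDs
        # ...the real class has additional knobs (prompts, token budgets, etc.)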