diff --git a/README.md b/README.md index 00da54fb..018a94e6 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,8 @@ class QueryParam: """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" max_token_for_local_context: int = 4000 """Maximum number of tokens allocated for entity descriptions in local retrieval.""" + ids: list[str] | None = None # ONLY SUPPORTED FOR PG VECTOR DBs + """List of ids to filter the RAG.""" ... ``` diff --git a/lightrag/base.py b/lightrag/base.py index 4b840b37..c84c7c62 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -81,6 +81,9 @@ class QueryParam: history_turns: int = 3 """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" + ids: list[str] | None = None + """List of ids to filter the results.""" + @dataclass class StorageNameSpace(ABC): @@ -107,7 +110,9 @@ class BaseVectorStorage(StorageNameSpace, ABC): meta_fields: set[str] = field(default_factory=set) @abstractmethod - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results.""" @abstractmethod diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 3a636e6a..1d525bdb 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -438,6 +438,8 @@ class PGVectorStorage(BaseVectorStorage): "entity_name": item["entity_name"], "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), + "chunk_id": item["source_id"], + # TODO: add document_id } return upsert_sql, data @@ -450,6 +452,8 @@ class PGVectorStorage(BaseVectorStorage): "target_id": item["tgt_id"], "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), + "chunk_id": item["source_id"], + # TODO: add document_id } return upsert_sql, data @@ -492,13 +496,20 @@ 
class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) #################### query method ############### - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) + if ids: + formatted_ids = ",".join("'{}'".format(str(i).replace("'", "''")) for i in ids) + else: + formatted_ids = "NULL" + sql = SQL_TEMPLATES[self.base_namespace].format( - embedding_string=embedding_string + embedding_string=embedding_string, doc_ids=formatted_ids ) params = { "workspace": self.db.workspace, @@ -1491,6 +1502,7 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, + chunk_id VARCHAR(255) NULL, CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) )""" }, @@ -1504,6 +1516,7 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, + chunk_id VARCHAR(255) NULL, CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) )""" }, @@ -1586,8 +1599,9 @@ SQL_TEMPLATES = { content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP """, - "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector) - VALUES ($1, $2, $3, $4, $5) + "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, + content_vector, chunk_id) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (workspace,id) DO UPDATE SET entity_name=EXCLUDED.entity_name, content=EXCLUDED.content, @@ -1595,8 +1609,8 @@ SQL_TEMPLATES = { update_time=CURRENT_TIMESTAMP """, "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, - target_id, content, content_vector) - VALUES ($1, $2, $3, $4, $5, $6) + target_id, content, content_vector, chunk_id) + VALUES ($1, $2, $3, $4, 
$5, $6, $7) ON CONFLICT (workspace,id) DO UPDATE SET source_id=EXCLUDED.source_id, target_id=EXCLUDED.target_id, @@ -1604,21 +1618,21 @@ SQL_TEMPLATES = { content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP """, # SQL for VectorStorage - "entities": """SELECT entity_name FROM - (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_ENTITY where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM - (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_RELATION where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "chunks": """SELECT id FROM - (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, + # "entities": """SELECT entity_name FROM + # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_VDB_ENTITY where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, + # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM + # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_VDB_RELATION where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, + # "chunks": """SELECT id FROM + # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, # DROP tables "drop_all": """ DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE; @@ -1642,4 +1656,55 @@ SQL_TEMPLATES = { "drop_vdb_relation": """ DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE; """, + "relationships": 
""" + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT source_id as src_id, target_id as tgt_id + FROM ( + SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_RELATION r + WHERE r.workspace=$1 + AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks) + ) filtered + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + """, + "entities": """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT entity_name FROM + ( + SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_ENTITY + where workspace=$1 + AND chunk_id IN (SELECT chunk_id FROM relevant_chunks) + ) + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + """, + "chunks": """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT id FROM + ( + SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_DOC_CHUNKS + where workspace=$1 + AND id IN (SELECT chunk_id FROM relevant_chunks) + ) + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + """, } diff --git a/lightrag/operate.py b/lightrag/operate.py index 09e51fcf..5baec1eb 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -962,7 +962,10 @@ async def mix_kg_vector_query( try: # Reduce top_k for vector search in hybrid mode since we have structured information from KG mix_topk = min(10, query_param.top_k) - results = await chunks_vdb.query(augmented_query, top_k=mix_topk) + # TODO: add ids to the query + results = await chunks_vdb.query( + augmented_query, top_k=mix_topk, ids=query_param.ids + ) if not results: return None @@ -1171,7 +1174,11 @@ async def 
_get_node_data( logger.info( f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" ) - results = await entities_vdb.query(query, top_k=query_param.top_k) + + results = await entities_vdb.query( + query, top_k=query_param.top_k, ids=query_param.ids + ) + if not len(results): return "", "", "" # get entity information @@ -1424,7 +1431,10 @@ async def _get_edge_data( logger.info( f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" ) - results = await relationships_vdb.query(keywords, top_k=query_param.top_k) + + results = await relationships_vdb.query( + keywords, top_k=query_param.top_k, ids=query_param.ids + ) if not len(results): return "", "", "" @@ -1673,7 +1683,9 @@ async def naive_query( if cached_response is not None: return cached_response - results = await chunks_vdb.query(query, top_k=query_param.top_k) + results = await chunks_vdb.query( + query, top_k=query_param.top_k, ids=query_param.ids + ) if not len(results): return PROMPTS["fail_response"]