diff --git a/lightrag/base.py b/lightrag/base.py index e7ab3127..20fe2a5b 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -108,9 +108,8 @@ class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc cosine_better_than_threshold: float = field(default=0.2) meta_fields: set[str] = field(default_factory=set) - @abstractmethod - async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]: + async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results.""" @abstractmethod diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 3fc05f59..c1ca7aa9 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -439,6 +439,7 @@ class PGVectorStorage(BaseVectorStorage): "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), "chunk_id": item["source_id"], + #TODO: add document_id } return upsert_sql, data @@ -452,6 +453,7 @@ class PGVectorStorage(BaseVectorStorage): "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), "chunk_id": item["source_id"] + #TODO: add document_id } return upsert_sql, data @@ -494,13 +496,19 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) #################### query method ############### - async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]: + async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) + if ids: + formatted_ids = ",".join(f"'{id}'" for id in ids) + else: + formatted_ids = "NULL" + sql = SQL_TEMPLATES[self.base_namespace].format( - embedding_string=embedding_string + embedding_string=embedding_string, + doc_ids=formatted_ids ) params = { "workspace": self.db.workspace, @@ -1389,7 +1397,6 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, - document_id VARCHAR(255) NULL, chunk_id VARCHAR(255) NULL, CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) )""" @@ -1404,7 +1411,6 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, - document_id VARCHAR(255) NULL, chunk_id VARCHAR(255) NULL, CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) )""" @@ -1507,21 +1513,21 @@ SQL_TEMPLATES = { content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP """, # SQL for VectorStorage - "entities": """SELECT entity_name FROM - (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_ENTITY where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM - (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_RELATION where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "chunks": """SELECT id FROM - (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, + # "entities": """SELECT entity_name FROM + # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_VDB_ENTITY where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, + # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM + # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_VDB_RELATION where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, + # "chunks": """SELECT id FROM + # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) + # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + # """, # DROP tables "drop_all": """ DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE; @@ -1545,4 +1551,56 @@ SQL_TEMPLATES = { "drop_vdb_relation": """ DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE; """, -} + "relationships": """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT source_id as src_id, target_id as tgt_id + FROM ( + SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_RELATION r + WHERE r.workspace=$1 + AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks) + ) filtered + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + """, + "entities": + ''' + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT entity_name FROM + ( + SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_ENTITY + where workspace=$1 + AND chunk_id IN (SELECT chunk_id FROM relevant_chunks) + ) + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + ''', + 'chunks': """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) + ) + SELECT id FROM + ( + SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_DOC_CHUNKS + where workspace=$1 + AND chunk_id IN (SELECT chunk_id FROM relevant_chunks) + ) + WHERE distance>$2 + ORDER BY distance DESC + LIMIT $3 + """ +} \ No newline at end of file diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ae6fd9dc..0554ab76 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1243,7 +1243,6 @@ class LightRAG: embedding_func=self.embedding_func, ), system_prompt=system_prompt, - ids = param.ids ) elif param.mode == "naive": response = await naive_query( diff --git a/lightrag/operate.py b/lightrag/operate.py index 6c0e1e4c..7910917a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -602,7 +602,6 @@ async def kg_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, - ids: list[str] | None = None, ) -> str | AsyncIterator[str]: # Handle cache use_model_func = global_config["llm_model_func"] @@ -650,7 +649,6 @@ async def kg_query( relationships_vdb, text_chunks_db, query_param, - ids ) if query_param.only_need_context: @@ -1035,7 +1033,6 @@ async def _build_query_context( relationships_vdb, text_chunks_db, query_param, - ids = ids ) else: # hybrid mode ll_data, hl_data = await asyncio.gather( @@ -1104,7 +1101,9 @@ async def _get_node_data( logger.info( f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" ) - results = await entities_vdb.query(query, top_k=query_param.top_k) + + results = await entities_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids) + if not len(results): return "", "", "" # get entity information @@ -1352,16 +1351,12 @@ async def _get_edge_data( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - ids: list[str] | None = None, ): logger.info( f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" ) - if ids: - #TODO: add ids to the query - results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = ids) - else: - results = await relationships_vdb.query(keywords, top_k=query_param.top_k) + + results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = query_param.ids) if not len(results): return "", "", "" @@ -1610,7 +1605,7 @@ async def naive_query( if cached_response is not None: return cached_response - results = await chunks_vdb.query(query, top_k=query_param.top_k) + results = await chunks_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids) if not len(results): return PROMPTS["fail_response"]