Update project dependencies and example test files

- Updated requirements.txt with latest package versions
- Added support for filtering query results by IDs in base and operate modules
- Modified PostgreSQL vector storage to include document and chunk ID fields
Roy
2025-03-07 18:45:28 +00:00
parent 5e7ef39998
commit 0ec61d6407
6 changed files with 91 additions and 20 deletions

View File

@@ -81,6 +81,9 @@ class QueryParam:
history_turns: int = 3
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
ids: list[str] | None = None
"""List of ids to filter the results."""
@dataclass
class StorageNameSpace(ABC):
@@ -107,7 +110,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
meta_fields: set[str] = field(default_factory=set)
@abstractmethod
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results."""
@abstractmethod

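The widened abstract signature means every concrete vector backend now has to accept an optional `ids` argument. Below is a minimal sketch of a conforming backend; the class, its storage layout, and the truncation in place of real similarity ranking are hypothetical and not part of this commit:

```python
from dataclasses import dataclass, field
from typing import Any

# Hypothetical backend used only to illustrate the new query contract;
# it is not code from this commit.
@dataclass
class ToyVectorStorage:
    _rows: list[dict[str, Any]] = field(default_factory=list)

    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        candidates = self._rows
        if ids is not None:
            # Restrict the candidate set to the requested IDs before ranking,
            # which is the behaviour the new parameter enables.
            wanted = set(ids)
            candidates = [row for row in candidates if row.get("id") in wanted]
        # A real backend would rank by embedding similarity; this sketch just
        # truncates to top_k.
        return candidates[:top_k]
```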
View File

@@ -492,7 +492,7 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data)
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding))
@@ -1387,6 +1387,8 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
document_id VARCHAR(255) NULL,
chunk_id VARCHAR(255) NULL,
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1400,6 +1402,8 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
document_id VARCHAR(255) NULL,
chunk_id VARCHAR(255) NULL,
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
)"""
},

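The nullable `document_id` and `chunk_id` columns added to both vector tables are what make ID-based filtering expressible on the Postgres side. The commit's actual query text is not shown in these hunks; the helper below is only a sketch of how such a filter could be appended to a pgvector similarity search, reusing table and column names from the DDL above (the query shape, placeholder style, and helper name are assumptions):

```python
# Sketch only, not the commit's SQL: one way an optional ID filter could be
# folded into a pgvector similarity search over LIGHTRAG_VDB_ENTITY.
def entity_query_sql(embedding_string: str, top_k: int, ids: list[str] | None = None) -> str:
    id_filter = ""
    if ids:
        quoted = ",".join(f"'{i}'" for i in ids)  # assumes IDs are validated upstream
        id_filter = f" AND chunk_id IN ({quoted})"
    return (
        "SELECT id, content FROM LIGHTRAG_VDB_ENTITY"
        " WHERE workspace = $1"
        + id_filter
        + f" ORDER BY content_vector <=> '[{embedding_string}]'::vector"
        + f" LIMIT {top_k}"
    )
```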
View File

@@ -1243,6 +1243,7 @@ class LightRAG:
embedding_func=self.embedding_func,
),
system_prompt=system_prompt,
ids=param.ids,
)
elif param.mode == "naive":
response = await naive_query(

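With `param.ids` forwarded into `kg_query`, callers can scope retrieval to specific source documents. A minimal usage sketch, assuming the usual `LightRAG`/`QueryParam` entry points; the working directory, model configuration, and ID values are placeholders:

```python
from lightrag import LightRAG, QueryParam

# Usage sketch: LLM and embedding configuration are omitted for brevity and the
# document IDs are placeholders. In the hunk above, only the kg_query branch
# receives param.ids.
rag = LightRAG(working_dir="./rag_storage")

answer = rag.query(
    "What does the contract say about termination?",
    param=QueryParam(mode="hybrid", ids=["doc-1", "doc-2"]),
)
print(answer)
```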
View File

@@ -602,6 +602,7 @@ async def kg_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
ids: list[str] | None = None,
) -> str | AsyncIterator[str]:
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -649,6 +650,7 @@ async def kg_query(
relationships_vdb,
text_chunks_db,
query_param,
ids
)
if query_param.only_need_context:
@@ -1016,6 +1018,7 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
ids: list[str] | None = None,
):
if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data(
@@ -1032,6 +1035,7 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
ids=ids,
)
else: # hybrid mode
ll_data, hl_data = await asyncio.gather(
@@ -1348,11 +1352,16 @@ async def _get_edge_data(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
ids: list[str] | None = None,
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if ids:
    # TODO: add ids to the query
    results = await relationships_vdb.query(keywords, top_k=query_param.top_k, ids=ids)
else:
    results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
    return "", "", ""
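Taken together, the hunks above thread `ids` down a single call chain. The condensed sketch below shows that flow with most parameters and all real logic elided; it is an outline of the pattern, not the repository's implementations:

```python
# Condensed outline of the call chain touched by this commit; signatures are
# trimmed to the parts relevant to ID filtering and the bodies are placeholders.
async def kg_query(query, relationships_vdb, query_param, ids=None):
    return await _build_query_context(query, relationships_vdb, query_param, ids=ids)

async def _build_query_context(query, relationships_vdb, query_param, ids=None):
    # The diff above forwards ids on the branch that retrieves edge data.
    return await _get_edge_data(query, relationships_vdb, query_param, ids=ids)

async def _get_edge_data(keywords, relationships_vdb, query_param, ids=None):
    if ids:
        return await relationships_vdb.query(keywords, top_k=query_param.top_k, ids=ids)
    return await relationships_vdb.query(keywords, top_k=query_param.top_k)
```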