Merge pull request #1 from ArindamRoy23/document_query_filter
Document query filter
This commit is contained in:
21
.gitignore
vendored
21
.gitignore
vendored
@@ -64,3 +64,24 @@ gui/
|
|||||||
|
|
||||||
# unit-test files
|
# unit-test files
|
||||||
test_*
|
test_*
|
||||||
|
Miniconda3-latest-Linux-x86_64.sh
|
||||||
|
requirements_basic.txt
|
||||||
|
requirements.txt
|
||||||
|
examples/test_chromadb.py
|
||||||
|
examples/test_faiss.py
|
||||||
|
examples/test_neo4j.py
|
||||||
|
.gitignore
|
||||||
|
requirements.txt
|
||||||
|
examples/test_chromadb.py
|
||||||
|
examples/test_faiss.py
|
||||||
|
examples/*
|
||||||
|
tests/test_lightrag_ollama_chat.py
|
||||||
|
requirements.txt
|
||||||
|
requirements.txt
|
||||||
|
examples/test_chromadb.py
|
||||||
|
examples/test_faiss.py
|
||||||
|
examples/test_neo4j.py
|
||||||
|
tests/test_lightrag_ollama_chat.py
|
||||||
|
examples/test_chromadb.py
|
||||||
|
examples/test_faiss.py
|
||||||
|
examples/test_neo4j.py
|
||||||
|
@@ -81,6 +81,9 @@ class QueryParam:
|
|||||||
history_turns: int = 3
|
history_turns: int = 3
|
||||||
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
||||||
|
|
||||||
|
ids: list[str] | None = None
|
||||||
|
"""List of ids to filter the results."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageNameSpace(ABC):
|
class StorageNameSpace(ABC):
|
||||||
@@ -105,9 +108,8 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
cosine_better_than_threshold: float = field(default=0.2)
|
cosine_better_than_threshold: float = field(default=0.2)
|
||||||
meta_fields: set[str] = field(default_factory=set)
|
meta_fields: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
|
||||||
"""Query the vector storage and retrieve top_k results."""
|
"""Query the vector storage and retrieve top_k results."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@@ -438,6 +438,8 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
"entity_name": item["entity_name"],
|
"entity_name": item["entity_name"],
|
||||||
"content": item["content"],
|
"content": item["content"],
|
||||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||||
|
"chunk_id": item["source_id"],
|
||||||
|
#TODO: add document_id
|
||||||
}
|
}
|
||||||
return upsert_sql, data
|
return upsert_sql, data
|
||||||
|
|
||||||
@@ -450,6 +452,8 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
"target_id": item["tgt_id"],
|
"target_id": item["tgt_id"],
|
||||||
"content": item["content"],
|
"content": item["content"],
|
||||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||||
|
"chunk_id": item["source_id"]
|
||||||
|
#TODO: add document_id
|
||||||
}
|
}
|
||||||
return upsert_sql, data
|
return upsert_sql, data
|
||||||
|
|
||||||
@@ -492,13 +496,19 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
await self.db.execute(upsert_sql, data)
|
await self.db.execute(upsert_sql, data)
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
|
||||||
embeddings = await self.embedding_func([query])
|
embeddings = await self.embedding_func([query])
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
embedding_string = ",".join(map(str, embedding))
|
embedding_string = ",".join(map(str, embedding))
|
||||||
|
|
||||||
|
if ids:
|
||||||
|
formatted_ids = ",".join(f"'{id}'" for id in ids)
|
||||||
|
else:
|
||||||
|
formatted_ids = "NULL"
|
||||||
|
|
||||||
sql = SQL_TEMPLATES[self.base_namespace].format(
|
sql = SQL_TEMPLATES[self.base_namespace].format(
|
||||||
embedding_string=embedding_string
|
embedding_string=embedding_string,
|
||||||
|
doc_ids=formatted_ids
|
||||||
)
|
)
|
||||||
params = {
|
params = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
@@ -1387,6 +1397,7 @@ TABLES = {
|
|||||||
content_vector VECTOR,
|
content_vector VECTOR,
|
||||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
update_time TIMESTAMP,
|
update_time TIMESTAMP,
|
||||||
|
chunk_id VARCHAR(255) NULL,
|
||||||
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
|
||||||
)"""
|
)"""
|
||||||
},
|
},
|
||||||
@@ -1400,6 +1411,7 @@ TABLES = {
|
|||||||
content_vector VECTOR,
|
content_vector VECTOR,
|
||||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
update_time TIMESTAMP,
|
update_time TIMESTAMP,
|
||||||
|
chunk_id VARCHAR(255) NULL,
|
||||||
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
|
||||||
)"""
|
)"""
|
||||||
},
|
},
|
||||||
@@ -1482,8 +1494,9 @@ SQL_TEMPLATES = {
|
|||||||
content_vector=EXCLUDED.content_vector,
|
content_vector=EXCLUDED.content_vector,
|
||||||
update_time = CURRENT_TIMESTAMP
|
update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
content_vector, chunk_id)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET entity_name=EXCLUDED.entity_name,
|
SET entity_name=EXCLUDED.entity_name,
|
||||||
content=EXCLUDED.content,
|
content=EXCLUDED.content,
|
||||||
@@ -1491,8 +1504,8 @@ SQL_TEMPLATES = {
|
|||||||
update_time=CURRENT_TIMESTAMP
|
update_time=CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
|
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
|
||||||
target_id, content, content_vector)
|
target_id, content, content_vector, chunk_id)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET source_id=EXCLUDED.source_id,
|
SET source_id=EXCLUDED.source_id,
|
||||||
target_id=EXCLUDED.target_id,
|
target_id=EXCLUDED.target_id,
|
||||||
@@ -1500,21 +1513,21 @@ SQL_TEMPLATES = {
|
|||||||
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
|
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"entities": """SELECT entity_name FROM
|
# "entities": """SELECT entity_name FROM
|
||||||
(SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
# (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
# FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
||||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||||
""",
|
# """,
|
||||||
"relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
# "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
||||||
(SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
# (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
# FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
||||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||||
""",
|
# """,
|
||||||
"chunks": """SELECT id FROM
|
# "chunks": """SELECT id FROM
|
||||||
(SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
# (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
# FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
||||||
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
||||||
""",
|
# """,
|
||||||
# DROP tables
|
# DROP tables
|
||||||
"drop_all": """
|
"drop_all": """
|
||||||
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
||||||
@@ -1538,4 +1551,56 @@ SQL_TEMPLATES = {
|
|||||||
"drop_vdb_relation": """
|
"drop_vdb_relation": """
|
||||||
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
||||||
""",
|
""",
|
||||||
}
|
"relationships": """
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS
|
||||||
|
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||||
|
)
|
||||||
|
SELECT source_id as src_id, target_id as tgt_id
|
||||||
|
FROM (
|
||||||
|
SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
|
FROM LIGHTRAG_VDB_RELATION r
|
||||||
|
WHERE r.workspace=$1
|
||||||
|
AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
||||||
|
) filtered
|
||||||
|
WHERE distance>$2
|
||||||
|
ORDER BY distance DESC
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
"entities":
|
||||||
|
'''
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS
|
||||||
|
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||||
|
)
|
||||||
|
SELECT entity_name FROM
|
||||||
|
(
|
||||||
|
SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
|
FROM LIGHTRAG_VDB_ENTITY
|
||||||
|
where workspace=$1
|
||||||
|
AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
||||||
|
)
|
||||||
|
WHERE distance>$2
|
||||||
|
ORDER BY distance DESC
|
||||||
|
LIMIT $3
|
||||||
|
''',
|
||||||
|
'chunks': """
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS
|
||||||
|
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
||||||
|
)
|
||||||
|
SELECT id FROM
|
||||||
|
(
|
||||||
|
SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS
|
||||||
|
where workspace=$1
|
||||||
|
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
||||||
|
)
|
||||||
|
WHERE distance>$2
|
||||||
|
ORDER BY distance DESC
|
||||||
|
LIMIT $3
|
||||||
|
"""
|
||||||
|
}
|
@@ -892,7 +892,8 @@ async def mix_kg_vector_query(
|
|||||||
try:
|
try:
|
||||||
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
|
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
|
||||||
mix_topk = min(10, query_param.top_k)
|
mix_topk = min(10, query_param.top_k)
|
||||||
results = await chunks_vdb.query(augmented_query, top_k=mix_topk)
|
# TODO: add ids to the query
|
||||||
|
results = await chunks_vdb.query(augmented_query, top_k=mix_topk, ids = query_param.ids)
|
||||||
if not results:
|
if not results:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1016,6 +1017,7 @@ async def _build_query_context(
|
|||||||
relationships_vdb: BaseVectorStorage,
|
relationships_vdb: BaseVectorStorage,
|
||||||
text_chunks_db: BaseKVStorage,
|
text_chunks_db: BaseKVStorage,
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
|
ids: list[str] = None,
|
||||||
):
|
):
|
||||||
if query_param.mode == "local":
|
if query_param.mode == "local":
|
||||||
entities_context, relations_context, text_units_context = await _get_node_data(
|
entities_context, relations_context, text_units_context = await _get_node_data(
|
||||||
@@ -1100,7 +1102,9 @@ async def _get_node_data(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
||||||
)
|
)
|
||||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
|
||||||
|
results = await entities_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
# get entity information
|
# get entity information
|
||||||
@@ -1352,7 +1356,8 @@ async def _get_edge_data(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||||
)
|
)
|
||||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
|
||||||
|
results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = query_param.ids)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
@@ -1601,7 +1606,7 @@ async def naive_query(
|
|||||||
if cached_response is not None:
|
if cached_response is not None:
|
||||||
return cached_response
|
return cached_response
|
||||||
|
|
||||||
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
results = await chunks_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
@@ -1,17 +1,50 @@
|
|||||||
aiohttp
|
aioboto3==14.1.0
|
||||||
configparser
|
aiofiles==24.1.0
|
||||||
future
|
aiohttp==3.11.13
|
||||||
|
ascii_colors==0.5.2
|
||||||
# Basic modules
|
asyncpg==0.30.0
|
||||||
gensim
|
chromadb==0.6.3
|
||||||
pipmaster
|
community==1.0.0b1
|
||||||
pydantic
|
docx==0.2.4
|
||||||
python-dotenv
|
# faiss
|
||||||
|
fastapi==0.115.11
|
||||||
setuptools
|
glm==0.4.4
|
||||||
tenacity
|
graspologic==3.4.1
|
||||||
|
gunicorn==23.0.0
|
||||||
# LLM packages
|
httpx==0.28.1
|
||||||
tiktoken
|
imgui_bundle==1.6.2
|
||||||
|
jsonlines==4.0.0
|
||||||
# Extra libraries are installed when needed using pipmaster
|
llama_index==0.12.22
|
||||||
|
moderngl==5.12.0
|
||||||
|
motor==3.7.0
|
||||||
|
nano_vectordb==0.0.4.3
|
||||||
|
neo4j==5.28.1
|
||||||
|
nest_asyncio==1.6.0
|
||||||
|
networkx==3.4.2
|
||||||
|
numpy
|
||||||
|
openpyxl==3.1.5
|
||||||
|
oracledb==3.0.0
|
||||||
|
Pillow==11.1.0
|
||||||
|
pipmaster==0.4.0
|
||||||
|
protobuf
|
||||||
|
psutil==7.0.0
|
||||||
|
psycopg==3.2.5
|
||||||
|
psycopg_pool==3.2.6
|
||||||
|
pydantic==2.10.6
|
||||||
|
pymilvus==2.5.4
|
||||||
|
pymongo==4.11.2
|
||||||
|
PyPDF2==3.0.1
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
pyvis==0.3.2
|
||||||
|
qdrant_client==1.13.3
|
||||||
|
redis==5.2.1
|
||||||
|
Requests==2.32.3
|
||||||
|
sentence_transformers==3.4.1
|
||||||
|
setuptools==75.8.0
|
||||||
|
SQLAlchemy==2.0.38
|
||||||
|
starlette==0.46.0
|
||||||
|
tenacity==9.0.0
|
||||||
|
tiktoken==0.9.0
|
||||||
|
torch==2.6.0
|
||||||
|
transformers==4.49.0
|
||||||
|
uvicorn==0.34.0
|
||||||
|
@@ -38,16 +38,16 @@ class McpError(Exception):
|
|||||||
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"server": {
|
"server": {
|
||||||
"host": "localhost",
|
"host": "host.docker.internal",
|
||||||
"port": 9621,
|
"port": 11434,
|
||||||
"model": "lightrag:latest",
|
"model": "llama3.2:latest",
|
||||||
"timeout": 300,
|
"timeout": 300,
|
||||||
"max_retries": 1,
|
"max_retries": 1,
|
||||||
"retry_delay": 1,
|
"retry_delay": 1,
|
||||||
},
|
},
|
||||||
"test_cases": {
|
"test_cases": {
|
||||||
"basic": {"query": "唐僧有几个徒弟"},
|
"basic": {"query": "How many disciples did Tang Seng have?"},
|
||||||
"generate": {"query": "电视剧西游记导演是谁"},
|
"generate": {"query": "Who directed the TV series Journey to the West?"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -763,8 +763,8 @@ def parse_args() -> argparse.Namespace:
|
|||||||
Configuration file (config.json):
|
Configuration file (config.json):
|
||||||
{
|
{
|
||||||
"server": {
|
"server": {
|
||||||
"host": "localhost", # Server address
|
"host": "host.docker.internal", # Server address
|
||||||
"port": 9621, # Server port
|
"port": 11434, # Server port
|
||||||
"model": "lightrag:latest" # Default model name
|
"model": "lightrag:latest" # Default model name
|
||||||
},
|
},
|
||||||
"test_cases": {
|
"test_cases": {
|
||||||
|
Reference in New Issue
Block a user