Fix linting
This commit is contained in:
@@ -66,12 +66,14 @@ load_dotenv(override=True)
|
|||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini")
|
config.read("config.ini")
|
||||||
|
|
||||||
|
|
||||||
class DefaultRAGStorageConfig:
|
class DefaultRAGStorageConfig:
|
||||||
KV_STORAGE = "JsonKVStorage"
|
KV_STORAGE = "JsonKVStorage"
|
||||||
VECTOR_STORAGE = "NanoVectorDBStorage"
|
VECTOR_STORAGE = "NanoVectorDBStorage"
|
||||||
GRAPH_STORAGE = "NetworkXStorage"
|
GRAPH_STORAGE = "NetworkXStorage"
|
||||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||||
|
|
||||||
|
|
||||||
# Global progress tracker
|
# Global progress tracker
|
||||||
scan_progress: Dict = {
|
scan_progress: Dict = {
|
||||||
"is_scanning": False,
|
"is_scanning": False,
|
||||||
@@ -317,22 +319,30 @@ def parse_args() -> argparse.Namespace:
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-storage",
|
"--kv-storage",
|
||||||
default=get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE),
|
default=get_env_value(
|
||||||
|
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||||
|
),
|
||||||
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc-status-storage",
|
"--doc-status-storage",
|
||||||
default=get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE),
|
default=get_env_value(
|
||||||
|
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||||
|
),
|
||||||
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--graph-storage",
|
"--graph-storage",
|
||||||
default=get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE),
|
default=get_env_value(
|
||||||
|
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||||
|
),
|
||||||
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vector-storage",
|
"--vector-storage",
|
||||||
default=get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE),
|
default=get_env_value(
|
||||||
|
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||||
|
),
|
||||||
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -725,7 +735,12 @@ def create_app(args):
|
|||||||
for storage_name, storage_instance in storage_instances:
|
for storage_name, storage_instance in storage_instances:
|
||||||
if isinstance(
|
if isinstance(
|
||||||
storage_instance,
|
storage_instance,
|
||||||
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
(
|
||||||
|
PGKVStorage,
|
||||||
|
PGVectorStorage,
|
||||||
|
PGGraphStorage,
|
||||||
|
PGDocStatusStorage,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
storage_instance.db = postgres_db
|
storage_instance.db = postgres_db
|
||||||
logger.info(f"Injected postgres_db to {storage_name}")
|
logger.info(f"Injected postgres_db to {storage_name}")
|
||||||
@@ -790,11 +805,11 @@ def create_app(args):
|
|||||||
if postgres_db and hasattr(postgres_db, "pool"):
|
if postgres_db and hasattr(postgres_db, "pool"):
|
||||||
await postgres_db.pool.close()
|
await postgres_db.pool.close()
|
||||||
logger.info("Closed PostgreSQL connection pool")
|
logger.info("Closed PostgreSQL connection pool")
|
||||||
|
|
||||||
if oracle_db and hasattr(oracle_db, "pool"):
|
if oracle_db and hasattr(oracle_db, "pool"):
|
||||||
await oracle_db.pool.close()
|
await oracle_db.pool.close()
|
||||||
logger.info("Closed Oracle connection pool")
|
logger.info("Closed Oracle connection pool")
|
||||||
|
|
||||||
if tidb_db and hasattr(tidb_db, "pool"):
|
if tidb_db and hasattr(tidb_db, "pool"):
|
||||||
await tidb_db.pool.close()
|
await tidb_db.pool.close()
|
||||||
logger.info("Closed TiDB connection pool")
|
logger.info("Closed TiDB connection pool")
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
user_collection_settings = config.get("collection_settings", {})
|
user_collection_settings = config.get("collection_settings", {})
|
||||||
|
@@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
# Where to save index file if you want persistent storage
|
# Where to save index file if you want persistent storage
|
||||||
|
@@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client = MilvusClient(
|
self._client = MilvusClient(
|
||||||
@@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|||||||
data=embedding,
|
data=embedding,
|
||||||
limit=top_k,
|
limit=top_k,
|
||||||
output_fields=list(self.meta_fields),
|
output_fields=list(self.meta_fields),
|
||||||
search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
|
search_params={
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"params": {"radius": self.cosine_better_than_threshold},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
print(results)
|
print(results)
|
||||||
return [
|
return [
|
||||||
|
@@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
import array
|
import array
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
|
|
||||||
# import html
|
# import html
|
||||||
# import os
|
# import os
|
||||||
@@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
|
@@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
def _upsert_chunks(self, item: dict):
|
def _upsert_chunks(self, item: dict):
|
||||||
@@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||||
"""Return keys that don't exist in storage"""
|
"""Return keys that don't exist in storage"""
|
||||||
keys = ",".join([f"'{_id}'" for _id in data])
|
keys = ",".join([f"'{_id}'" for _id in data])
|
||||||
sql = (
|
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||||
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
|
||||||
)
|
|
||||||
result = await self.db.query(sql, multirows=True)
|
result = await self.db.query(sql, multirows=True)
|
||||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||||
if result is None:
|
if result is None:
|
||||||
|
@@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client = QdrantClient(
|
self._client = QdrantClient(
|
||||||
@@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||||||
)
|
)
|
||||||
logger.debug(f"query result: {results}")
|
logger.debug(f"query result: {results}")
|
||||||
# 添加余弦相似度过滤
|
# 添加余弦相似度过滤
|
||||||
filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold]
|
filtered_results = [
|
||||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results]
|
dp for dp in results if dp.score >= self.cosine_better_than_threshold
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
|
||||||
|
]
|
||||||
|
@@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
if cosine_threshold is None:
|
if cosine_threshold is None:
|
||||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
raise ValueError(
|
||||||
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||||
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||||
|
@@ -426,7 +426,7 @@ class LightRAG:
|
|||||||
}
|
}
|
||||||
self.vector_db_storage_cls_kwargs = {
|
self.vector_db_storage_cls_kwargs = {
|
||||||
**default_vector_db_kwargs,
|
**default_vector_db_kwargs,
|
||||||
**self.vector_db_storage_cls_kwargs
|
**self.vector_db_storage_cls_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
# show config
|
# show config
|
||||||
@@ -532,8 +532,6 @@ class LightRAG:
|
|||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||||
partial(
|
partial(
|
||||||
self.llm_model_func,
|
self.llm_model_func,
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
@@ -35,7 +34,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def chunking_by_token_size(
|
def chunking_by_token_size(
|
||||||
content: str,
|
content: str,
|
||||||
split_by_character: Union[str, None] = None,
|
split_by_character: Union[str, None] = None,
|
||||||
@@ -1057,7 +1055,9 @@ async def _get_node_data(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
):
|
):
|
||||||
# get similar entities
|
# get similar entities
|
||||||
logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}")
|
logger.info(
|
||||||
|
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
||||||
|
)
|
||||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
@@ -1273,7 +1273,9 @@ async def _get_edge_data(
|
|||||||
text_chunks_db: BaseKVStorage,
|
text_chunks_db: BaseKVStorage,
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
):
|
):
|
||||||
logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}")
|
logger.info(
|
||||||
|
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||||
|
)
|
||||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
|
Reference in New Issue
Block a user