Fix linting

This commit is contained in:
yangdx
2025-02-13 04:12:00 +08:00
parent d25386ff1b
commit ed73ea4076
11 changed files with 64 additions and 30 deletions

View File

@@ -66,12 +66,14 @@ load_dotenv(override=True)
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini") config.read("config.ini")
class DefaultRAGStorageConfig: class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage" KV_STORAGE = "JsonKVStorage"
VECTOR_STORAGE = "NanoVectorDBStorage" VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage" GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage" DOC_STATUS_STORAGE = "JsonDocStatusStorage"
# Global progress tracker # Global progress tracker
scan_progress: Dict = { scan_progress: Dict = {
"is_scanning": False, "is_scanning": False,
@@ -317,22 +319,30 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--kv-storage", "--kv-storage",
default=get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE), default=get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
),
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--doc-status-storage", "--doc-status-storage",
default=get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE), default=get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
),
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--graph-storage", "--graph-storage",
default=get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE), default=get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
),
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
) )
parser.add_argument( parser.add_argument(
"--vector-storage", "--vector-storage",
default=get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE), default=get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
),
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
) )
@@ -725,7 +735,12 @@ def create_app(args):
for storage_name, storage_instance in storage_instances: for storage_name, storage_instance in storage_instances:
if isinstance( if isinstance(
storage_instance, storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), (
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
),
): ):
storage_instance.db = postgres_db storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}") logger.info(f"Injected postgres_db to {storage_name}")
@@ -790,11 +805,11 @@ def create_app(args):
if postgres_db and hasattr(postgres_db, "pool"): if postgres_db and hasattr(postgres_db, "pool"):
await postgres_db.pool.close() await postgres_db.pool.close()
logger.info("Closed PostgreSQL connection pool") logger.info("Closed PostgreSQL connection pool")
if oracle_db and hasattr(oracle_db, "pool"): if oracle_db and hasattr(oracle_db, "pool"):
await oracle_db.pool.close() await oracle_db.pool.close()
logger.info("Closed Oracle connection pool") logger.info("Closed Oracle connection pool")
if tidb_db and hasattr(tidb_db, "pool"): if tidb_db and hasattr(tidb_db, "pool"):
await tidb_db.pool.close() await tidb_db.pool.close()
logger.info("Closed TiDB connection pool") logger.info("Closed TiDB connection pool")

View File

@@ -1,4 +1,3 @@
import os
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
@@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
user_collection_settings = config.get("collection_settings", {}) user_collection_settings = config.get("collection_settings", {})

View File

@@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
# Where to save index file if you want persistent storage # Where to save index file if you want persistent storage

View File

@@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
self._client = MilvusClient( self._client = MilvusClient(
@@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage):
data=embedding, data=embedding,
limit=top_k, limit=top_k,
output_fields=list(self.meta_fields), output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}}, search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
) )
print(results) print(results)
return [ return [

View File

@@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
self._client_file_name = os.path.join( self._client_file_name = os.path.join(

View File

@@ -1,6 +1,5 @@
import array import array
import asyncio import asyncio
import os
# import html # import html
# import os # import os
@@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):

View File

@@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
def _upsert_chunks(self, item: dict): def _upsert_chunks(self, item: dict):
@@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data]) keys = ",".join([f"'{_id}'" for _id in data])
sql = ( sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
)
result = await self.db.query(sql, multirows=True) result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None: if result is None:

View File

@@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
self._client = QdrantClient( self._client = QdrantClient(
@@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
logger.debug(f"query result: {results}") logger.debug(f"query result: {results}")
# 添加余弦相似度过滤 # 添加余弦相似度过滤
filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold] filtered_results = [
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results] dp for dp in results if dp.score >= self.cosine_better_than_threshold
]
return [
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
]

View File

@@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict]:

View File

@@ -426,7 +426,7 @@ class LightRAG:
} }
self.vector_db_storage_cls_kwargs = { self.vector_db_storage_cls_kwargs = {
**default_vector_db_kwargs, **default_vector_db_kwargs,
**self.vector_db_storage_cls_kwargs **self.vector_db_storage_cls_kwargs,
} }
# show config # show config
@@ -532,8 +532,6 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial( partial(
self.llm_model_func, self.llm_model_func,

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import json import json
import os
import re import re
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from typing import Any, Union from typing import Any, Union
@@ -35,7 +34,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS
import time import time
def chunking_by_token_size( def chunking_by_token_size(
content: str, content: str,
split_by_character: Union[str, None] = None, split_by_character: Union[str, None] = None,
@@ -1057,7 +1055,9 @@ async def _get_node_data(
query_param: QueryParam, query_param: QueryParam,
): ):
# get similar entities # get similar entities
logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}") logger.info(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
)
results = await entities_vdb.query(query, top_k=query_param.top_k) results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
return "", "", "" return "", "", ""
@@ -1273,7 +1273,9 @@ async def _get_edge_data(
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
): ):
logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}") logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k) results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results): if not len(results):