From ed73ea407643a9c004fed56b9383ee42ce741e66 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 04:12:00 +0800 Subject: [PATCH] Fix linting --- lightrag/api/lightrag_server.py | 29 ++++++++++++++++++++++------- lightrag/kg/chroma_impl.py | 5 +++-- lightrag/kg/faiss_impl.py | 4 +++- lightrag/kg/milvus_impl.py | 9 +++++++-- lightrag/kg/nano_vector_db_impl.py | 4 +++- lightrag/kg/oracle_impl.py | 5 +++-- lightrag/kg/postgres_impl.py | 8 ++++---- lightrag/kg/qdrant_impl.py | 12 +++++++++--- lightrag/kg/tidb_impl.py | 4 +++- lightrag/lightrag.py | 4 +--- lightrag/operate.py | 10 ++++++---- 11 files changed, 64 insertions(+), 30 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1f531c4f..b8e4f1e6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -66,12 +66,14 @@ load_dotenv(override=True) config = configparser.ConfigParser() config.read("config.ini") + class DefaultRAGStorageConfig: KV_STORAGE = "JsonKVStorage" VECTOR_STORAGE = "NanoVectorDBStorage" GRAPH_STORAGE = "NetworkXStorage" DOC_STATUS_STORAGE = "JsonDocStatusStorage" + # Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -317,22 +319,30 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--kv-storage", - default=get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE), + default=get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ), help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", ) parser.add_argument( "--doc-status-storage", - default=get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE), + default=get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ), help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", ) parser.add_argument( "--graph-storage", - default=get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE), + default=get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ), help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", ) parser.add_argument( "--vector-storage", - default=get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE), + default=get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ), help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", ) @@ -725,7 +735,12 @@ def create_app(args): for storage_name, storage_instance in storage_instances: if isinstance( storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ( + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, + ), ): storage_instance.db = postgres_db logger.info(f"Injected postgres_db to {storage_name}") @@ -790,11 +805,11 @@ def create_app(args): if postgres_db and hasattr(postgres_db, "pool"): await postgres_db.pool.close() logger.info("Closed PostgreSQL connection pool") - + if oracle_db and hasattr(oracle_db, "pool"): await oracle_db.pool.close() logger.info("Closed Oracle connection pool") - + if tidb_db and hasattr(tidb_db, "pool"): await tidb_db.pool.close() logger.info("Closed TiDB connection pool") diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 242c93ea..82e723a1 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,4 +1,3 @@ -import os import asyncio from dataclasses import dataclass from typing import Union @@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold user_collection_settings = config.get("collection_settings", {}) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 47111a47..0dca9e4c 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold # Where to save index file if you want persistent storage diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index dd50c026..1abec502 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client = MilvusClient( @@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage): data=embedding, limit=top_k, output_fields=list(self.meta_fields), - search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}}, + search_params={ + "metric_type": "COSINE", + "params": {"radius": self.cosine_better_than_threshold}, + }, ) print(results) return [ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 5a61bf4f..2db8f72a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client_file_name = os.path.join( diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 5a1e0616..65f1060c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -1,6 +1,5 @@ import array import asyncio -import os # import html # import os @@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold async def upsert(self, data: dict[str, dict]): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index dde88739..cb636d7f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold def _upsert_chunks(self, item: dict): @@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that don't exist in storage""" keys = ",".join([f"'{_id}'" for _id in data]) - sql = ( - f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" - ) + sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" result = await self.db.query(sql, multirows=True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. if result is None: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 88dce27f..7c9f21a0 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client = QdrantClient( @@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) logger.debug(f"query result: {results}") # 添加余弦相似度过滤 - filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold] - return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results] + filtered_results = [ + dp for dp in results if dp.score >= self.cosine_better_than_threshold + ] + return [ + {**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results + ] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 248f2c85..00b8003d 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold async def query(self, query: str, top_k: int) -> list[dict]: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 66508faf..cdb0462e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -426,7 +426,7 @@ class LightRAG: } self.vector_db_storage_cls_kwargs = { **default_vector_db_kwargs, - **self.vector_db_storage_cls_kwargs + **self.vector_db_storage_cls_kwargs, } # show config @@ -532,8 +532,6 @@ class LightRAG: embedding_func=self.embedding_func, ) - - self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, diff --git a/lightrag/operate.py b/lightrag/operate.py index f8d484af..04aad0d4 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,6 +1,5 @@ import asyncio import json -import os import re from tqdm.asyncio import tqdm as tqdm_async from typing import Any, Union @@ -35,7 +34,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS import time - def chunking_by_token_size( content: str, split_by_character: Union[str, None] = None, @@ -1057,7 +1055,9 @@ async def _get_node_data( query_param: QueryParam, ): # get similar entities - logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}") + logger.info( + f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" + ) results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return "", "", "" @@ -1273,7 +1273,9 @@ async def _get_edge_data( text_chunks_db: BaseKVStorage, query_param: QueryParam, ): - logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}") + logger.info( + f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" + ) results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results):