From 8cfca5a141e615deee13bbb5557b77825784dd8a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 03:29:40 +0800 Subject: [PATCH] Fix linting --- lightrag/kg/milvus_impl.py | 28 +++++++++++++++++++++++----- lightrag/kg/mongo_impl.py | 36 +++++++++++++++++++++++++++++++----- lightrag/kg/neo4j_impl.py | 14 +++++++++++--- lightrag/kg/qdrant_impl.py | 9 +++++++-- lightrag/kg/redis_impl.py | 5 ++++- 5 files changed, 76 insertions(+), 16 deletions(-) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 35260652..ae0daac2 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -16,6 +16,7 @@ from pymilvus import MilvusClient config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class MilvusVectorDBStorge(BaseVectorStorage): @staticmethod @@ -30,11 +31,28 @@ class MilvusVectorDBStorge(BaseVectorStorage): def __post_init__(self): self._client = MilvusClient( - uri = os.environ.get("MILVUS_URI", config.get("milvus", "uri", fallback=os.path.join(self.global_config["working_dir"], "milvus_lite.db"))), - user = os.environ.get("MILVUS_USER", config.get("milvus", "user", fallback=None)), - password = os.environ.get("MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)), - token = os.environ.get("MILVUS_TOKEN", config.get("milvus", "token", fallback=None)), - db_name = os.environ.get("MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)), + uri=os.environ.get( + "MILVUS_URI", + config.get( + "milvus", + "uri", + fallback=os.path.join( + self.global_config["working_dir"], "milvus_lite.db" + ), + ), + ), + user=os.environ.get( + "MILVUS_USER", config.get("milvus", "user", fallback=None) + ), + password=os.environ.get( + "MILVUS_PASSWORD", config.get("milvus", "password", fallback=None) + ), + token=os.environ.get( + "MILVUS_TOKEN", config.get("milvus", "token", fallback=None) + ), + db_name=os.environ.get( + "MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None) + ), ) self._max_batch_size = self.global_config["embedding_batch_num"] MilvusVectorDBStorge.create_collection_if_not_exist( diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 142050cf..017f57bd 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -22,13 +22,24 @@ from ..utils import logger config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): client = MongoClient( - os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) + os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) + ) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + ) ) - database = client.get_database(os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))) self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") @@ -91,10 +102,25 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self.client = AsyncIOMotorClient( - os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) + os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) ) - self.db = self.client[os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))] - self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"))] + self.db = self.client[ + os.environ.get( + "MONGO_DATABASE", + mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + ) + ] + self.collection = self.db[ + os.environ.get( + "MONGO_KG_COLLECTION", + config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"), + ) + ] # # ------------------------------------------------------------------------- diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 61f85ad4..dfb58825 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -31,6 +31,7 @@ from ..base import BaseGraphStorage config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -47,9 +48,16 @@ class Neo4JStorage(BaseGraphStorage): self._driver_lock = asyncio.Lock() URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)] - USERNAME = os.environ["NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)] - PASSWORD = os.environ["NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)] - MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", config.get("neo4j", "connection_pool_size", fallback=800)) + USERNAME = os.environ[ + "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) + ] + PASSWORD = os.environ[ + "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None) + ] + MAX_CONNECTION_POOL_SIZE = os.environ.get( + "NEO4J_MAX_CONNECTION_POOL_SIZE", + config.get("neo4j", "connection_pool_size", fallback=800), + ) DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) ) diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 0a971b73..bda23f8d 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -19,6 +19,7 @@ from qdrant_client import QdrantClient, models config = configparser.ConfigParser() config.read("config.ini", "utf-8") + def compute_mdhash_id_for_qdrant( content: str, prefix: str = "", style: str = "simple" ) -> str: @@ -59,8 +60,12 @@ class QdrantVectorDBStorage(BaseVectorStorage): def __post_init__(self): self._client = QdrantClient( - url=os.environ.get("QDRANT_URL", config.get("qdrant", "uri", fallback=None)), - api_key=os.environ.get("QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)), + url=os.environ.get( + "QDRANT_URL", config.get("qdrant", "uri", fallback=None) + ), + api_key=os.environ.get( + "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None) + ), ) self._max_batch_size = self.global_config["embedding_batch_num"] QdrantVectorDBStorage.create_collection_if_not_exist( diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index fef62d5f..ed8f46f9 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -18,10 +18,13 @@ import json config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self): - redis_url = os.environ.get("REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")) + redis_url = os.environ.get( + "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") + ) self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}")