From 7ec769456cc561e6c221933b91269ec2d92b287a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 03:55:15 +0800 Subject: [PATCH] Inject Postgres to LightRag storage class when needed --- lightrag/lightrag.py | 80 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6b9161be..3603509a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -17,6 +17,7 @@ from .base import ( StorageNameSpace, ) from .kg.oracle_impl import OracleDB +from .kg.postgres_impl import PostgreSQLDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -394,6 +395,18 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace=self.namespace_prefix + "json_doc_status_storage", + embedding_func=None, + ) + + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), + embedding_func=self.embedding_func, + ) + # 检查是否使用了 Oracle 存储实现 if ( self.kv_storage == "OracleKVStorage" @@ -403,14 +416,16 @@ class LightRAG: # 从环境变量或配置文件获取参数 dbconfig = { "user": os.environ.get( - "ORACLE_USER", config.get("oracle", "user", fallback=None) + "ORACLE_USER", + config.get("oracle", "user", fallback=None), ), "password": os.environ.get( "ORACLE_PASSWORD", config.get("oracle", "password", fallback=None), ), "dsn": os.environ.get( - "ORACLE_DSN", config.get("oracle", "dsn", fallback=None) + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), ), "config_dir": os.environ.get( "ORACLE_CONFIG_DIR", @@ -441,17 +456,58 @@ class LightRAG: if self.graph_storage == "OracleGraphStorage": self.graph_storage_cls.db = oracle_db - self.json_doc_status_storage = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "json_doc_status_storage", - embedding_func=None, - ) + # 检查是否使用了 PostgreSQL 存储实现 + if ( + self.kv_storage == "PGKVStorage" + or self.vector_storage == "PGVectorStorage" + or self.graph_storage == "PGGraphStorage" + or self.json_doc_status_storage == "PGDocStatusStorage" + ): + # 读取配置文件 + config_parser = configparser.ConfigParser() + if os.path.exists("config.ini"): + config_parser.read("config.ini") - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - embedding_func=self.embedding_func, - ) + # 从环境变量或配置文件获取参数 + dbconfig = { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + # 初始化 PostgreSQLDB 对象 + postgres_db = PostgreSQLDB(dbconfig) + loop = always_get_an_event_loop() + loop.run_until_complete(postgres_db.initdb()) + + # 只对 PostgreSQL 实现的存储类注入 db 对象 + if self.kv_storage == "PGKVStorage": + self.key_string_value_json_storage_cls.db = postgres_db + if self.vector_storage == "PGVectorStorage": + self.vector_db_storage_cls.db = postgres_db + if self.graph_storage == "PGGraphStorage": + self.graph_storage_cls.db = postgres_db + if self.json_doc_status_storage == "OracleGraphStorage": + self.json_doc_status_storage = postgres_db #### # add embedding func by walter