From 3f845e9e532da84935c0a74a74738021356eb3f5 Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 8 Feb 2025 16:05:59 +0800 Subject: [PATCH 1/2] better handling of namespace --- examples/copy_llm_cache_to_another_storage.py | 9 +-- lightrag/kg/mongo_impl.py | 10 +-- lightrag/kg/oracle_impl.py | 42 +++++++------ lightrag/kg/postgres_impl.py | 54 ++++++++-------- lightrag/kg/postgres_impl_test.py | 8 ++- lightrag/kg/tidb_impl.py | 61 ++++++++++++------- lightrag/lightrag.py | 40 ++++++------ lightrag/namespace.py | 25 ++++++++ 8 files changed, 156 insertions(+), 93 deletions(-) create mode 100644 lightrag/namespace.py diff --git a/examples/copy_llm_cache_to_another_storage.py b/examples/copy_llm_cache_to_another_storage.py index b9378c7c..5d07ad13 100644 --- a/examples/copy_llm_cache_to_another_storage.py +++ b/examples/copy_llm_cache_to_another_storage.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage from lightrag.storage import JsonKVStorage +from lightrag.namespace import NameSpace load_dotenv() ROOT_DIR = os.environ.get("ROOT_DIR") @@ -39,14 +40,14 @@ async def copy_from_postgres_to_json(): await postgres_db.initdb() from_llm_response_cache = PGKVStorage( - namespace="llm_response_cache", + namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE, global_config={"embedding_batch_num": 6}, embedding_func=None, db=postgres_db, ) to_llm_response_cache = JsonKVStorage( - namespace="llm_response_cache", + namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE, global_config={"working_dir": WORKING_DIR}, embedding_func=None, ) @@ -72,13 +73,13 @@ async def copy_from_json_to_postgres(): await postgres_db.initdb() from_llm_response_cache = JsonKVStorage( - namespace="llm_response_cache", + namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE, global_config={"working_dir": WORKING_DIR}, embedding_func=None, ) to_llm_response_cache = PGKVStorage( - namespace="llm_response_cache", + 
namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE, global_config={"embedding_batch_num": 6}, embedding_func=None, db=postgres_db, diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 7cef8c0f..7afc4240 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -13,10 +13,10 @@ if not pm.is_installed("motor"): from pymongo import MongoClient from motor.motor_asyncio import AsyncIOMotorClient from typing import Union, List, Tuple -from lightrag.utils import logger -from lightrag.base import BaseKVStorage -from lightrag.base import BaseGraphStorage +from ..utils import logger +from ..base import BaseKVStorage, BaseGraphStorage +from ..namespace import NameSpace, is_namespace @dataclass @@ -52,7 +52,7 @@ class MongoKVStorage(BaseKVStorage): return set([s for s in data if s not in existing_ids]) async def upsert(self, data: dict[str, dict]): - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): for mode, items in data.items(): for k, v in tqdm_async(items.items(), desc="Upserting"): key = f"{mode}_{k}" @@ -69,7 +69,7 @@ class MongoKVStorage(BaseKVStorage): return data async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): res = {} v = self._data.find_one({"_id": mode + "_" + id}) if v: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 8af01c47..32fbaa10 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -19,6 +19,7 @@ from ..base import ( BaseKVStorage, BaseVectorStorage, ) +from ..namespace import NameSpace, is_namespace import oracledb @@ -185,7 +186,7 @@ class OracleKVStorage(BaseKVStorage): SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - if self.namespace.endswith("llm_response_cache"): 
+ if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: @@ -201,7 +202,7 @@ class OracleKVStorage(BaseKVStorage): """Specifically for llm_response_cache.""" SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: @@ -218,7 +219,7 @@ class OracleKVStorage(BaseKVStorage): params = {"workspace": self.db.workspace} # print("get_by_ids:"+SQL) res = await self.db.query(SQL, params, multirows=True) - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): modes = set() dict_res: dict[str, dict] = {} for row in res: @@ -256,7 +257,7 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) + table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys]) ) params = {"workspace": self.db.workspace} res = await self.db.query(SQL, params, multirows=True) @@ -269,7 +270,7 @@ class OracleKVStorage(BaseKVStorage): ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): - if self.namespace.endswith("text_chunks"): + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { "id": k, @@ -302,7 +303,7 @@ class OracleKVStorage(BaseKVStorage): "status": item["status"], } await self.db.execute(merge_sql, _data) - if self.namespace.endswith("full_docs"): + if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): for k, v in 
data.items(): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] @@ -313,7 +314,7 @@ class OracleKVStorage(BaseKVStorage): } await self.db.execute(merge_sql, _data) - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): for mode, items in data.items(): for k, v in items.items(): upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] @@ -329,15 +330,16 @@ class OracleKVStorage(BaseKVStorage): return None async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) + SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace)) params = {"workspace": self.db.workspace, "id": id, "status": status} await self.db.execute(SQL, params) async def index_done_callback(self): - for n in ("full_docs", "text_chunks"): - if self.namespace.endswith(n): - logger.info("full doc and chunk data had been saved into oracle db!") - break + if is_namespace( + self.namespace, + (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), + ): + logger.info("full doc and chunk data had been saved into oracle db!") @dataclass @@ -614,13 +616,19 @@ class OracleGraphStorage(BaseGraphStorage): N_T = { - "full_docs": "LIGHTRAG_DOC_FULL", - "text_chunks": "LIGHTRAG_DOC_CHUNKS", - "chunks": "LIGHTRAG_DOC_CHUNKS", - "entities": "LIGHTRAG_GRAPH_NODES", - "relationships": "LIGHTRAG_GRAPH_EDGES", + NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", + NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES", + NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", } +def namespace_to_table_name(namespace: str) -> str: + for k, v in N_T.items(): + if is_namespace(namespace, k): + return v + + TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( diff --git 
a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a031a0c3..8884d92e 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -32,6 +32,7 @@ from ..base import ( BaseGraphStorage, T, ) +from ..namespace import NameSpace, is_namespace if sys.platform.startswith("win"): import asyncio.windows_events @@ -187,7 +188,7 @@ class PGKVStorage(BaseKVStorage): """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) res = {} for row in array_res: @@ -203,7 +204,7 @@ class PGKVStorage(BaseKVStorage): """Specifically for llm_response_cache.""" sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] params = {"workspace": self.db.workspace, mode: mode, "id": id} - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) res = {} for row in array_res: @@ -219,7 +220,7 @@ class PGKVStorage(BaseKVStorage): ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) modes = set() dict_res: dict[str, dict] = {} @@ -239,7 +240,7 @@ class PGKVStorage(BaseKVStorage): return None async def all_keys(self) -> list[dict]: - if self.namespace.endswith("llm_response_cache"): + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): sql = "select workspace,mode,id from lightrag_llm_cache" res = await self.db.query(sql, multirows=True) return res @@ -251,7 +252,7 @@ class PGKVStorage(BaseKVStorage): async def filter_keys(self, keys: List[str]) -> 
Set[str]: """Filter out duplicated content""" sql = SQL_TEMPLATES["filter_keys"].format( - table_name=NAMESPACE_TABLE_MAP[self.namespace], + table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys]), ) params = {"workspace": self.db.workspace} @@ -270,9 +271,9 @@ class PGKVStorage(BaseKVStorage): ################ INSERT METHODS ################ async def upsert(self, data: Dict[str, dict]): - if self.namespace.endswith("text_chunks"): + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): pass - elif self.namespace.endswith("full_docs"): + elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): for k, v in data.items(): upsert_sql = SQL_TEMPLATES["upsert_doc_full"] _data = { @@ -281,7 +282,7 @@ class PGKVStorage(BaseKVStorage): "workspace": self.db.workspace, } await self.db.execute(upsert_sql, _data) - elif self.namespace.endswith("llm_response_cache"): + elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): for mode, items in data.items(): for k, v in items.items(): upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] @@ -296,12 +297,11 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) async def index_done_callback(self): - for n in ("full_docs", "text_chunks"): - if self.namespace.endswith(n): - logger.info( - "full doc and chunk data had been saved into postgresql db!" 
- ) - break + if is_namespace( + self.namespace, + (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), + ): + logger.info("full doc and chunk data had been saved into postgresql db!") @dataclass @@ -393,11 +393,11 @@ class PGVectorStorage(BaseVectorStorage): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] for item in list_data: - if self.namespace.endswith("chunks"): + if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): upsert_sql, data = self._upsert_chunks(item) - elif self.namespace.endswith("entities"): + elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): upsert_sql, data = self._upsert_entities(item) - elif self.namespace.endswith("relationships"): + elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): upsert_sql, data = self._upsert_relationships(item) else: raise ValueError(f"{self.namespace} is not supported") @@ -1027,16 +1027,22 @@ class PGGraphStorage(BaseGraphStorage): NAMESPACE_TABLE_MAP = { - "full_docs": "LIGHTRAG_DOC_FULL", - "text_chunks": "LIGHTRAG_DOC_CHUNKS", - "chunks": "LIGHTRAG_DOC_CHUNKS", - "entities": "LIGHTRAG_VDB_ENTITY", - "relationships": "LIGHTRAG_VDB_RELATION", - "doc_status": "LIGHTRAG_DOC_STATUS", - "llm_response_cache": "LIGHTRAG_LLM_CACHE", + NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", + NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", + NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION", + NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS", + NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", } +def namespace_to_table_name(namespace: str) -> str: + for k, v in NAMESPACE_TABLE_MAP.items(): + if is_namespace(namespace, k): + return v + + TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py index 
eb6e6e73..304d556c 100644 --- a/lightrag/kg/postgres_impl_test.py +++ b/lightrag/kg/postgres_impl_test.py @@ -12,7 +12,9 @@ if not pm.is_installed("asyncpg"): import asyncpg import psycopg from psycopg_pool import AsyncConnectionPool -from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage + +from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage +from ..namespace import NameSpace DB = "rag" USER = "rag" @@ -76,7 +78,7 @@ db = PostgreSQLDB( async def query_with_age(): await db.initdb() graph = PGGraphStorage( - namespace="chunk_entity_relation", + namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, global_config={}, embedding_func=None, ) @@ -92,7 +94,7 @@ async def query_with_age(): async def create_edge_with_age(): await db.initdb() graph = PGGraphStorage( - namespace="chunk_entity_relation", + namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, global_config={}, embedding_func=None, ) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 4a21d067..cb819d47 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -14,8 +14,9 @@ if not pm.is_installed("sqlalchemy"): from sqlalchemy import create_engine, text from tqdm import tqdm -from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage -from lightrag.utils import logger +from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage +from ..utils import logger +from ..namespace import NameSpace, is_namespace class TiDB(object): @@ -138,8 +139,8 @@ class TiDBKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=N_T[self.namespace], - id_field=N_ID[self.namespace], + table_name=namespace_to_table_name(self.namespace), + id_field=namespace_to_id(self.namespace), ids=",".join([f"'{id}'" for id in keys]), ) try: @@ -160,7 +161,7 @@ class TiDBKVStorage(BaseKVStorage): async def upsert(self, data: dict[str, dict]): left_data = {k: v for k, v 
in data.items() if k not in self._data} self._data.update(left_data) - if self.namespace.endswith("text_chunks"): + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { "__id__": k, @@ -196,7 +197,7 @@ class TiDBKVStorage(BaseKVStorage): ) await self.db.execute(merge_sql, data) - if self.namespace.endswith("full_docs"): + if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): merge_sql = SQL_TEMPLATES["upsert_doc_full"] data = [] for k, v in self._data.items(): @@ -211,10 +212,11 @@ class TiDBKVStorage(BaseKVStorage): return left_data async def index_done_callback(self): - for n in ("full_docs", "text_chunks"): - if self.namespace.endswith(n): - logger.info("full doc and chunk data had been saved into TiDB db!") - break + if is_namespace( + self.namespace, + (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), + ): + logger.info("full doc and chunk data had been saved into TiDB db!") @dataclass @@ -260,7 +262,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): if not len(data): logger.warning("You insert an empty data to vector DB") return [] - if self.namespace.endswith("chunks"): + if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): return [] logger.info(f"Inserting {len(data)} vectors to {self.namespace}") @@ -290,7 +292,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): for i, d in enumerate(list_data): d["content_vector"] = embeddings[i] - if self.namespace.endswith("entities"): + if is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): data = [] for item in list_data: param = { @@ -311,7 +313,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): merge_sql = SQL_TEMPLATES["insert_entity"] await self.db.execute(merge_sql, data) - elif self.namespace.endswith("relationships"): + elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): data = [] for item in list_data: param = { @@ -470,20 +472,33 @@ class TiDBGraphStorage(BaseGraphStorage): N_T = { - "full_docs": "LIGHTRAG_DOC_FULL", 
- "text_chunks": "LIGHTRAG_DOC_CHUNKS", - "chunks": "LIGHTRAG_DOC_CHUNKS", - "entities": "LIGHTRAG_GRAPH_NODES", - "relationships": "LIGHTRAG_GRAPH_EDGES", + NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", + NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES", + NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", } N_ID = { - "full_docs": "doc_id", - "text_chunks": "chunk_id", - "chunks": "chunk_id", - "entities": "entity_id", - "relationships": "relation_id", + NameSpace.KV_STORE_FULL_DOCS: "doc_id", + NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id", + NameSpace.VECTOR_STORE_CHUNKS: "chunk_id", + NameSpace.VECTOR_STORE_ENTITIES: "entity_id", + NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id", } + +def namespace_to_table_name(namespace: str) -> str: + for k, v in N_T.items(): + if is_namespace(namespace, k): + return v + + +def namespace_to_id(namespace: str) -> str: + for k, v in N_ID.items(): + if is_namespace(namespace, k): + return v + + TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """ diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index abc9390c..242c6832 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -35,6 +35,8 @@ from .base import ( DocStatus, ) +from .namespace import NameSpace, make_namespace + from .prompt import GRAPH_FIELD_SEP STORAGES = { @@ -228,8 +230,13 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace=self.namespace_prefix + "json_doc_status_storage", + embedding_func=None, + ) + self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), embedding_func=self.embedding_func, ) @@ -237,34 +244,33 @@ class LightRAG: # add embedding func by 
walter #### self.full_docs = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "full_docs", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "text_chunks", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace=self.namespace_prefix + "chunk_entity_relation", + namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION), embedding_func=self.embedding_func, ) - #### # add embedding func by walter over #### self.entities_vdb = self.vector_db_storage_cls( - namespace=self.namespace_prefix + "entities", + namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( - namespace=self.namespace_prefix + "relationships", + namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( - namespace=self.namespace_prefix + "chunks", + namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS), embedding_func=self.embedding_func, ) @@ -274,7 +280,7 @@ class LightRAG: hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), embedding_func=self.embedding_func, ) @@ -289,7 +295,7 @@ class LightRAG: # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status = 
self.doc_status_storage_cls( - namespace=self.namespace_prefix + "doc_status", + namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), global_config=global_config, embedding_func=None, ) @@ -925,7 +931,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -942,7 +948,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -961,7 +967,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1002,7 +1008,7 @@ class LightRAG: global_config=asdict(self), hashing_kv=self.llm_response_cache or self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1033,7 +1039,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, 
NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_funcne, ), @@ -1049,7 +1055,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1068,7 +1074,7 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "llm_response_cache", + namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), global_config=asdict(self), embedding_func=self.embedding_func, ), diff --git a/lightrag/namespace.py b/lightrag/namespace.py new file mode 100644 index 00000000..ba8e3072 --- /dev/null +++ b/lightrag/namespace.py @@ -0,0 +1,25 @@ +from typing import Iterable + + +class NameSpace: + KV_STORE_FULL_DOCS = "full_docs" + KV_STORE_TEXT_CHUNKS = "text_chunks" + KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache" + + VECTOR_STORE_ENTITIES = "entities" + VECTOR_STORE_RELATIONSHIPS = "relationships" + VECTOR_STORE_CHUNKS = "chunks" + + GRAPH_STORE_CHUNK_ENTITY_RELATION = "chunk_entity_relation" + + DOC_STATUS = "doc_status" + + +def make_namespace(prefix: str, base_namespace: str): + return prefix + base_namespace + + +def is_namespace(namespace: str, base_namespace: str | Iterable[str]): + if isinstance(base_namespace, str): + return namespace.endswith(base_namespace) + return any(is_namespace(namespace, ns) for ns in base_namespace) From f5bf6a4af805d2ca93de79f74304b0e7835b410c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 8 Feb 2025 16:06:07 +0800 Subject: [PATCH 2/2] use namespace as neo4j database name format fix --- lightrag/kg/neo4j_impl.py | 93 
+++++++++++++++++++++++--------------- lightrag/kg/oracle_impl.py | 8 +++- lightrag/lightrag.py | 60 ++++++++++++++++++------ 3 files changed, 107 insertions(+), 54 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f4e18446..fe01aaf3 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,6 +1,7 @@ import asyncio import inspect import os +import re from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict import pipmaster as pm @@ -22,7 +23,7 @@ from tenacity import ( retry_if_exception_type, ) -from lightrag.utils import logger +from ..utils import logger from ..base import BaseGraphStorage @@ -45,50 +46,68 @@ class Neo4JStorage(BaseGraphStorage): PASSWORD = os.environ["NEO4J_PASSWORD"] MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) DATABASE = os.environ.get( - "NEO4J_DATABASE" - ) # If this param is None, the home database will be used. If it is not None, the specified database will be used. 
- self._DATABASE = DATABASE + "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) + ) self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) - _database_name = "home database" if DATABASE is None else f"database {DATABASE}" + + # Try to connect to the database with GraphDatabase.driver( URI, auth=(USERNAME, PASSWORD), max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, ) as _sync_driver: - try: - with _sync_driver.session(database=DATABASE) as session: - try: - session.run("MATCH (n) RETURN n LIMIT 0") - logger.info(f"Connected to {DATABASE} at {URI}") - except neo4jExceptions.ServiceUnavailable as e: - logger.error( - f"{DATABASE} at {URI} is not available".capitalize() - ) - raise e - except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {DATABASE} at {URI}") - raise e - except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Database.DatabaseNotFound": - logger.info( - f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize() - ) + for database in (DATABASE, None): + self._DATABASE = database + connected = False + try: - with _sync_driver.session() as session: - session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") - logger.info(f"{DATABASE} at {URI} created".capitalize()) - except neo4jExceptions.ClientError as e: - if ( - e.code - == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" - ): - logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead." 
- ) - logger.error(f"Failed to create {DATABASE} at {URI}") + with _sync_driver.session(database=database) as session: + try: + session.run("MATCH (n) RETURN n LIMIT 0") + logger.info(f"Connected to {database} at {URI}") + connected = True + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"{database} at {URI} is not available".capitalize() + ) + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {database} at {URI}") raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{database} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + with _sync_driver.session() as session: + session.run( + f"CREATE DATABASE `{database}` IF NOT EXISTS" + ) + logger.info(f"{database} at {URI} created".capitalize()) + connected = True + except ( + neo4jExceptions.ClientError, + neo4jExceptions.DatabaseError, + ) as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ) or ( + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + ): + if database is not None: + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." 
+ ) + if database is None: + logger.error(f"Failed to create {database} at {URI}") + raise e + + if connected: + break def __post_init__(self): self._node_embed_algorithms = { @@ -117,7 +136,7 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) single_result = await result.single() logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}" ) return single_result["node_exists"] @@ -133,7 +152,7 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) single_result = await result.single() logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" ) return single_result["edgeExists"] diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 32fbaa10..a1a05759 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -257,7 +257,8 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys]) + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), ) params = {"workspace": self.db.workspace} res = await self.db.query(SQL, params, multirows=True) @@ -330,7 +331,9 @@ class OracleKVStorage(BaseKVStorage): return None async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace)) + SQL = SQL_TEMPLATES["change_status"].format( + table_name=namespace_to_table_name(self.namespace) + ) params = {"workspace": self.db.workspace, "id": id, "status": status} await 
self.db.execute(SQL, params) @@ -623,6 +626,7 @@ N_T = { NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", } + def namespace_to_table_name(namespace: str) -> str: for k, v in N_T.items(): if is_namespace(namespace, k): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 242c6832..6b925be3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -236,7 +236,9 @@ class LightRAG: ) self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), embedding_func=self.embedding_func, ) @@ -244,15 +246,21 @@ class LightRAG: # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS + ), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS + ), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION), + namespace=make_namespace( + self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION + ), embedding_func=self.embedding_func, ) #### @@ -260,17 +268,23 @@ class LightRAG: #### self.entities_vdb = self.vector_db_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES + ), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( - 
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS + ), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS + ), embedding_func=self.embedding_func, ) @@ -280,7 +294,9 @@ class LightRAG: hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), embedding_func=self.embedding_func, ) @@ -931,7 +947,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -948,7 +966,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -967,7 +987,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, 
NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1008,7 +1030,9 @@ class LightRAG: global_config=asdict(self), hashing_kv=self.llm_response_cache or self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1039,7 +1063,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1055,7 +1081,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1074,7 +1102,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ),