better handling of namespace

This commit is contained in:
ArnoChen
2025-02-08 16:05:59 +08:00
parent e787d92a0c
commit 3f845e9e53
8 changed files with 156 additions and 93 deletions

View File

@@ -11,6 +11,7 @@ from dotenv import load_dotenv
from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage
from lightrag.storage import JsonKVStorage from lightrag.storage import JsonKVStorage
from lightrag.namespace import NameSpace
load_dotenv() load_dotenv()
ROOT_DIR = os.environ.get("ROOT_DIR") ROOT_DIR = os.environ.get("ROOT_DIR")
@@ -39,14 +40,14 @@ async def copy_from_postgres_to_json():
await postgres_db.initdb() await postgres_db.initdb()
from_llm_response_cache = PGKVStorage( from_llm_response_cache = PGKVStorage(
namespace="llm_response_cache", namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
global_config={"embedding_batch_num": 6}, global_config={"embedding_batch_num": 6},
embedding_func=None, embedding_func=None,
db=postgres_db, db=postgres_db,
) )
to_llm_response_cache = JsonKVStorage( to_llm_response_cache = JsonKVStorage(
namespace="llm_response_cache", namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
global_config={"working_dir": WORKING_DIR}, global_config={"working_dir": WORKING_DIR},
embedding_func=None, embedding_func=None,
) )
@@ -72,13 +73,13 @@ async def copy_from_json_to_postgres():
await postgres_db.initdb() await postgres_db.initdb()
from_llm_response_cache = JsonKVStorage( from_llm_response_cache = JsonKVStorage(
namespace="llm_response_cache", namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
global_config={"working_dir": WORKING_DIR}, global_config={"working_dir": WORKING_DIR},
embedding_func=None, embedding_func=None,
) )
to_llm_response_cache = PGKVStorage( to_llm_response_cache = PGKVStorage(
namespace="llm_response_cache", namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
global_config={"embedding_batch_num": 6}, global_config={"embedding_batch_num": 6},
embedding_func=None, embedding_func=None,
db=postgres_db, db=postgres_db,

View File

@@ -13,10 +13,10 @@ if not pm.is_installed("motor"):
from pymongo import MongoClient from pymongo import MongoClient
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from typing import Union, List, Tuple from typing import Union, List, Tuple
from lightrag.utils import logger
from lightrag.base import BaseKVStorage from ..utils import logger
from lightrag.base import BaseGraphStorage from ..base import BaseKVStorage, BaseGraphStorage
from ..namespace import NameSpace, is_namespace
@dataclass @dataclass
@@ -52,7 +52,7 @@ class MongoKVStorage(BaseKVStorage):
return set([s for s in data if s not in existing_ids]) return set([s for s in data if s not in existing_ids])
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items(): for mode, items in data.items():
for k, v in tqdm_async(items.items(), desc="Upserting"): for k, v in tqdm_async(items.items(), desc="Upserting"):
key = f"{mode}_{k}" key = f"{mode}_{k}"
@@ -69,7 +69,7 @@ class MongoKVStorage(BaseKVStorage):
return data return data
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {} res = {}
v = self._data.find_one({"_id": mode + "_" + id}) v = self._data.find_one({"_id": mode + "_" + id})
if v: if v:

View File

@@ -19,6 +19,7 @@ from ..base import (
BaseKVStorage, BaseKVStorage,
BaseVectorStorage, BaseVectorStorage,
) )
from ..namespace import NameSpace, is_namespace
import oracledb import oracledb
@@ -185,7 +186,7 @@ class OracleKVStorage(BaseKVStorage):
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
# print("get_by_id:"+SQL) # print("get_by_id:"+SQL)
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(SQL, params, multirows=True) array_res = await self.db.query(SQL, params, multirows=True)
res = {} res = {}
for row in array_res: for row in array_res:
@@ -201,7 +202,7 @@ class OracleKVStorage(BaseKVStorage):
"""Specifically for llm_response_cache.""" """Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(SQL, params, multirows=True) array_res = await self.db.query(SQL, params, multirows=True)
res = {} res = {}
for row in array_res: for row in array_res:
@@ -218,7 +219,7 @@ class OracleKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
# print("get_by_ids:"+SQL) # print("get_by_ids:"+SQL)
res = await self.db.query(SQL, params, multirows=True) res = await self.db.query(SQL, params, multirows=True)
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
modes = set() modes = set()
dict_res: dict[str, dict] = {} dict_res: dict[str, dict] = {}
for row in res: for row in res:
@@ -256,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, multirows=True) res = await self.db.query(SQL, params, multirows=True)
@@ -269,7 +270,7 @@ class OracleKVStorage(BaseKVStorage):
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
if self.namespace.endswith("text_chunks"): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [ list_data = [
{ {
"id": k, "id": k,
@@ -302,7 +303,7 @@ class OracleKVStorage(BaseKVStorage):
"status": item["status"], "status": item["status"],
} }
await self.db.execute(merge_sql, _data) await self.db.execute(merge_sql, _data)
if self.namespace.endswith("full_docs"): if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items(): for k, v in data.items():
# values.clear() # values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"] merge_sql = SQL_TEMPLATES["merge_doc_full"]
@@ -313,7 +314,7 @@ class OracleKVStorage(BaseKVStorage):
} }
await self.db.execute(merge_sql, _data) await self.db.execute(merge_sql, _data)
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items(): for mode, items in data.items():
for k, v in items.items(): for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -329,15 +330,16 @@ class OracleKVStorage(BaseKVStorage):
return None return None
async def change_status(self, id: str, status: str): async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
params = {"workspace": self.db.workspace, "id": id, "status": status} params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params) await self.db.execute(SQL, params)
async def index_done_callback(self): async def index_done_callback(self):
for n in ("full_docs", "text_chunks"): if is_namespace(
if self.namespace.endswith(n): self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into oracle db!") logger.info("full doc and chunk data had been saved into oracle db!")
break
@dataclass @dataclass
@@ -614,13 +616,19 @@ class OracleGraphStorage(BaseGraphStorage):
N_T = { N_T = {
"full_docs": "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_GRAPH_NODES", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
"relationships": "LIGHTRAG_GRAPH_EDGES", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
} }
def namespace_to_table_name(namespace: str) -> str:
    """Resolve a (possibly prefixed) namespace to its Oracle table name.

    Walks ``N_T`` in insertion order and returns the table mapped to the
    first base namespace that ``is_namespace`` accepts for *namespace*.
    When no entry matches, ``None`` is returned.
    """
    matching_tables = (
        table for base, table in N_T.items() if is_namespace(namespace, base)
    )
    # First hit wins; fall back to None exactly like running off the loop.
    return next(matching_tables, None)
TABLES = { TABLES = {
"LIGHTRAG_DOC_FULL": { "LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (

View File

@@ -32,6 +32,7 @@ from ..base import (
BaseGraphStorage, BaseGraphStorage,
T, T,
) )
from ..namespace import NameSpace, is_namespace
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
import asyncio.windows_events import asyncio.windows_events
@@ -187,7 +188,7 @@ class PGKVStorage(BaseKVStorage):
"""Get doc_full data by id.""" """Get doc_full data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True) array_res = await self.db.query(sql, params, multirows=True)
res = {} res = {}
for row in array_res: for row in array_res:
@@ -203,7 +204,7 @@ class PGKVStorage(BaseKVStorage):
"""Specifically for llm_response_cache.""" """Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, mode: mode, "id": id} params = {"workspace": self.db.workspace, mode: mode, "id": id}
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True) array_res = await self.db.query(sql, params, multirows=True)
res = {} res = {}
for row in array_res: for row in array_res:
@@ -219,7 +220,7 @@ class PGKVStorage(BaseKVStorage):
ids=",".join([f"'{id}'" for id in ids]) ids=",".join([f"'{id}'" for id in ids])
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True) array_res = await self.db.query(sql, params, multirows=True)
modes = set() modes = set()
dict_res: dict[str, dict] = {} dict_res: dict[str, dict] = {}
@@ -239,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
return None return None
async def all_keys(self) -> list[dict]: async def all_keys(self) -> list[dict]:
if self.namespace.endswith("llm_response_cache"): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
sql = "select workspace,mode,id from lightrag_llm_cache" sql = "select workspace,mode,id from lightrag_llm_cache"
res = await self.db.query(sql, multirows=True) res = await self.db.query(sql, multirows=True)
return res return res
@@ -251,7 +252,7 @@ class PGKVStorage(BaseKVStorage):
async def filter_keys(self, keys: List[str]) -> Set[str]: async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format( sql = SQL_TEMPLATES["filter_keys"].format(
table_name=NAMESPACE_TABLE_MAP[self.namespace], table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]), ids=",".join([f"'{id}'" for id in keys]),
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
@@ -270,9 +271,9 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: Dict[str, dict]): async def upsert(self, data: Dict[str, dict]):
if self.namespace.endswith("text_chunks"): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass pass
elif self.namespace.endswith("full_docs"): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items(): for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"] upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
_data = { _data = {
@@ -281,7 +282,7 @@ class PGKVStorage(BaseKVStorage):
"workspace": self.db.workspace, "workspace": self.db.workspace,
} }
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
elif self.namespace.endswith("llm_response_cache"): elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items(): for mode, items in data.items():
for k, v in items.items(): for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -296,12 +297,11 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self): async def index_done_callback(self):
for n in ("full_docs", "text_chunks"): if is_namespace(
if self.namespace.endswith(n): self.namespace,
logger.info( (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
"full doc and chunk data had been saved into postgresql db!" ):
) logger.info("full doc and chunk data had been saved into postgresql db!")
break
@dataclass @dataclass
@@ -393,11 +393,11 @@ class PGVectorStorage(BaseVectorStorage):
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i] d["__vector__"] = embeddings[i]
for item in list_data: for item in list_data:
if self.namespace.endswith("chunks"): if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
upsert_sql, data = self._upsert_chunks(item) upsert_sql, data = self._upsert_chunks(item)
elif self.namespace.endswith("entities"): elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
upsert_sql, data = self._upsert_entities(item) upsert_sql, data = self._upsert_entities(item)
elif self.namespace.endswith("relationships"): elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
upsert_sql, data = self._upsert_relationships(item) upsert_sql, data = self._upsert_relationships(item)
else: else:
raise ValueError(f"{self.namespace} is not supported") raise ValueError(f"{self.namespace} is not supported")
@@ -1027,16 +1027,22 @@ class PGGraphStorage(BaseGraphStorage):
NAMESPACE_TABLE_MAP = { NAMESPACE_TABLE_MAP = {
"full_docs": "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_VDB_ENTITY", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
"relationships": "LIGHTRAG_VDB_RELATION", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
"doc_status": "LIGHTRAG_DOC_STATUS", NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
"llm_response_cache": "LIGHTRAG_LLM_CACHE", NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
} }
def namespace_to_table_name(namespace: str) -> str:
    """Resolve a (possibly prefixed) namespace to its PostgreSQL table name.

    Walks ``NAMESPACE_TABLE_MAP`` in insertion order and returns the table
    mapped to the first base namespace that ``is_namespace`` accepts for
    *namespace*. When no entry matches, ``None`` is returned.
    """
    matching_tables = (
        table
        for base, table in NAMESPACE_TABLE_MAP.items()
        if is_namespace(namespace, base)
    )
    # First hit wins; fall back to None exactly like running off the loop.
    return next(matching_tables, None)
TABLES = { TABLES = {
"LIGHTRAG_DOC_FULL": { "LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (

View File

@@ -12,7 +12,9 @@ if not pm.is_installed("asyncpg"):
import asyncpg import asyncpg
import psycopg import psycopg
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..namespace import NameSpace
DB = "rag" DB = "rag"
USER = "rag" USER = "rag"
@@ -76,7 +78,7 @@ db = PostgreSQLDB(
async def query_with_age(): async def query_with_age():
await db.initdb() await db.initdb()
graph = PGGraphStorage( graph = PGGraphStorage(
namespace="chunk_entity_relation", namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={}, global_config={},
embedding_func=None, embedding_func=None,
) )
@@ -92,7 +94,7 @@ async def query_with_age():
async def create_edge_with_age(): async def create_edge_with_age():
await db.initdb() await db.initdb()
graph = PGGraphStorage( graph = PGGraphStorage(
namespace="chunk_entity_relation", namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={}, global_config={},
embedding_func=None, embedding_func=None,
) )

View File

@@ -14,8 +14,9 @@ if not pm.is_installed("sqlalchemy"):
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from tqdm import tqdm from tqdm import tqdm
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger from ..utils import logger
from ..namespace import NameSpace, is_namespace
class TiDB(object): class TiDB(object):
@@ -138,8 +139,8 @@ class TiDBKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容""" """过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], table_name=namespace_to_table_name(self.namespace),
id_field=N_ID[self.namespace], id_field=namespace_to_id(self.namespace),
ids=",".join([f"'{id}'" for id in keys]), ids=",".join([f"'{id}'" for id in keys]),
) )
try: try:
@@ -160,7 +161,7 @@ class TiDBKVStorage(BaseKVStorage):
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data) self._data.update(left_data)
if self.namespace.endswith("text_chunks"): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [ list_data = [
{ {
"__id__": k, "__id__": k,
@@ -196,7 +197,7 @@ class TiDBKVStorage(BaseKVStorage):
) )
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
if self.namespace.endswith("full_docs"): if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
merge_sql = SQL_TEMPLATES["upsert_doc_full"] merge_sql = SQL_TEMPLATES["upsert_doc_full"]
data = [] data = []
for k, v in self._data.items(): for k, v in self._data.items():
@@ -211,10 +212,11 @@ class TiDBKVStorage(BaseKVStorage):
return left_data return left_data
async def index_done_callback(self): async def index_done_callback(self):
for n in ("full_docs", "text_chunks"): if is_namespace(
if self.namespace.endswith(n): self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!") logger.info("full doc and chunk data had been saved into TiDB db!")
break
@dataclass @dataclass
@@ -260,7 +262,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
return [] return []
if self.namespace.endswith("chunks"): if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
return [] return []
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
@@ -290,7 +292,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["content_vector"] = embeddings[i] d["content_vector"] = embeddings[i]
if self.namespace.endswith("entities"): if is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
data = [] data = []
for item in list_data: for item in list_data:
param = { param = {
@@ -311,7 +313,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
merge_sql = SQL_TEMPLATES["insert_entity"] merge_sql = SQL_TEMPLATES["insert_entity"]
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
elif self.namespace.endswith("relationships"): elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
data = [] data = []
for item in list_data: for item in list_data:
param = { param = {
@@ -470,20 +472,33 @@ class TiDBGraphStorage(BaseGraphStorage):
N_T = { N_T = {
"full_docs": "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_GRAPH_NODES", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
"relationships": "LIGHTRAG_GRAPH_EDGES", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
} }
N_ID = { N_ID = {
"full_docs": "doc_id", NameSpace.KV_STORE_FULL_DOCS: "doc_id",
"text_chunks": "chunk_id", NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id",
"chunks": "chunk_id", NameSpace.VECTOR_STORE_CHUNKS: "chunk_id",
"entities": "entity_id", NameSpace.VECTOR_STORE_ENTITIES: "entity_id",
"relationships": "relation_id", NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id",
} }
def namespace_to_table_name(namespace: str) -> str:
    """Resolve a (possibly prefixed) namespace to its TiDB table name.

    Walks ``N_T`` in insertion order and returns the table mapped to the
    first base namespace that ``is_namespace`` accepts for *namespace*.
    When no entry matches, ``None`` is returned.
    """
    matching_tables = (
        table for base, table in N_T.items() if is_namespace(namespace, base)
    )
    # First hit wins; fall back to None exactly like running off the loop.
    return next(matching_tables, None)
def namespace_to_id(namespace: str) -> str:
    """Resolve a (possibly prefixed) namespace to its id column name.

    Walks ``N_ID`` in insertion order and returns the column name mapped to
    the first base namespace that ``is_namespace`` accepts for *namespace*.
    When no entry matches, ``None`` is returned.
    """
    matching_id_fields = (
        id_field for base, id_field in N_ID.items() if is_namespace(namespace, base)
    )
    # First hit wins; fall back to None exactly like running off the loop.
    return next(matching_id_fields, None)
TABLES = { TABLES = {
"LIGHTRAG_DOC_FULL": { "LIGHTRAG_DOC_FULL": {
"ddl": """ "ddl": """

View File

@@ -35,6 +35,8 @@ from .base import (
DocStatus, DocStatus,
) )
from .namespace import NameSpace, make_namespace
from .prompt import GRAPH_FIELD_SEP from .prompt import GRAPH_FIELD_SEP
STORAGES = { STORAGES = {
@@ -228,8 +230,13 @@ class LightRAG:
self.graph_storage_cls, global_config=global_config self.graph_storage_cls, global_config=global_config
) )
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "json_doc_status_storage",
embedding_func=None,
)
self.llm_response_cache = self.key_string_value_json_storage_cls( self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -237,34 +244,33 @@ class LightRAG:
# add embedding func by walter # add embedding func by walter
#### ####
self.full_docs = self.key_string_value_json_storage_cls( self.full_docs = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "full_docs", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.text_chunks = self.key_string_value_json_storage_cls( self.text_chunks = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "text_chunks", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.chunk_entity_relation_graph = self.graph_storage_cls( self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace=self.namespace_prefix + "chunk_entity_relation", namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
#### ####
# add embedding func by walter over # add embedding func by walter over
#### ####
self.entities_vdb = self.vector_db_storage_cls( self.entities_vdb = self.vector_db_storage_cls(
namespace=self.namespace_prefix + "entities", namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"entity_name"}, meta_fields={"entity_name"},
) )
self.relationships_vdb = self.vector_db_storage_cls( self.relationships_vdb = self.vector_db_storage_cls(
namespace=self.namespace_prefix + "relationships", namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}, meta_fields={"src_id", "tgt_id"},
) )
self.chunks_vdb = self.vector_db_storage_cls( self.chunks_vdb = self.vector_db_storage_cls(
namespace=self.namespace_prefix + "chunks", namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -274,7 +280,7 @@ class LightRAG:
hashing_kv = self.llm_response_cache hashing_kv = self.llm_response_cache
else: else:
hashing_kv = self.key_string_value_json_storage_cls( hashing_kv = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -289,7 +295,7 @@ class LightRAG:
# Initialize document status storage # Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls( self.doc_status = self.doc_status_storage_cls(
namespace=self.namespace_prefix + "doc_status", namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config, global_config=global_config,
embedding_func=None, embedding_func=None,
) )
@@ -925,7 +931,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -942,7 +948,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -961,7 +967,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1002,7 +1008,7 @@ class LightRAG:
global_config=asdict(self), global_config=asdict(self),
hashing_kv=self.llm_response_cache hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls( or self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1033,7 +1039,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_funcne, embedding_func=self.embedding_funcne,
), ),
@@ -1049,7 +1055,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1068,7 +1074,7 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "llm_response_cache", namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),

25
lightrag/namespace.py Normal file
View File

@@ -0,0 +1,25 @@
from typing import Iterable
class NameSpace:
    """Base namespace names shared by the storage backends.

    Each attribute is the plain string suffix of a storage namespace; a
    caller-supplied prefix may be prepended to it (see ``make_namespace``).
    """

    # Key-value stores
    KV_STORE_FULL_DOCS = "full_docs"
    KV_STORE_TEXT_CHUNKS = "text_chunks"
    KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache"

    # Vector stores
    VECTOR_STORE_ENTITIES = "entities"
    VECTOR_STORE_RELATIONSHIPS = "relationships"
    VECTOR_STORE_CHUNKS = "chunks"

    # Graph store
    GRAPH_STORE_CHUNK_ENTITY_RELATION = "chunk_entity_relation"

    # Document status store
    DOC_STATUS = "doc_status"
def make_namespace(prefix: str, base_namespace: str):
    """Build a full namespace by prepending *prefix* to *base_namespace*.

    No separator is inserted; the two strings are concatenated as-is.
    """
    full_namespace = prefix + base_namespace
    return full_namespace
def is_namespace(namespace: str, base_namespace: "str | Iterable[str]") -> bool:
    """Tell whether *namespace* belongs to one of the given base namespaces.

    A namespace "belongs" to a base namespace when it ends with it, so any
    prefix added in front of the base name is ignored.

    Args:
        namespace: The full (possibly prefixed) namespace to test.
        base_namespace: A single base namespace, or an iterable of them;
            with an iterable the result is True if any element matches.

    Returns:
        True when *namespace* ends with a candidate base namespace,
        False otherwise (including for an empty iterable).

    Note:
        Matching is purely suffix-based, so a base namespace that is itself
        a suffix of another one also matches it (for example "text_chunks"
        matches the base namespace "chunks").
    """
    # The union annotation above is quoted on purpose: evaluating
    # ``str | Iterable[str]`` at definition time raises TypeError on
    # interpreters older than Python 3.10.
    if isinstance(base_namespace, str):
        return namespace.endswith(base_namespace)
    return any(is_namespace(namespace, candidate) for candidate in base_namespace)