better handling of namespace

This commit is contained in:
ArnoChen
2025-02-08 16:05:59 +08:00
parent e787d92a0c
commit 3f845e9e53
8 changed files with 156 additions and 93 deletions

View File

@@ -13,10 +13,10 @@ if not pm.is_installed("motor"):
from pymongo import MongoClient
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Union, List, Tuple
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
from lightrag.base import BaseGraphStorage
from ..utils import logger
from ..base import BaseKVStorage, BaseGraphStorage
from ..namespace import NameSpace, is_namespace
@dataclass
@@ -52,7 +52,7 @@ class MongoKVStorage(BaseKVStorage):
return set([s for s in data if s not in existing_ids])
async def upsert(self, data: dict[str, dict]):
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in tqdm_async(items.items(), desc="Upserting"):
key = f"{mode}_{k}"
@@ -69,7 +69,7 @@ class MongoKVStorage(BaseKVStorage):
return data
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = self._data.find_one({"_id": mode + "_" + id})
if v:

View File

@@ -19,6 +19,7 @@ from ..base import (
BaseKVStorage,
BaseVectorStorage,
)
from ..namespace import NameSpace, is_namespace
import oracledb
@@ -185,7 +186,7 @@ class OracleKVStorage(BaseKVStorage):
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
# print("get_by_id:"+SQL)
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
@@ -201,7 +202,7 @@ class OracleKVStorage(BaseKVStorage):
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
@@ -218,7 +219,7 @@ class OracleKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace}
# print("get_by_ids:"+SQL)
res = await self.db.query(SQL, params, multirows=True)
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
modes = set()
dict_res: dict[str, dict] = {}
for row in res:
@@ -256,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
)
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, multirows=True)
@@ -269,7 +270,7 @@ class OracleKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
if self.namespace.endswith("text_chunks"):
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [
{
"id": k,
@@ -302,7 +303,7 @@ class OracleKVStorage(BaseKVStorage):
"status": item["status"],
}
await self.db.execute(merge_sql, _data)
if self.namespace.endswith("full_docs"):
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items():
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"]
@@ -313,7 +314,7 @@ class OracleKVStorage(BaseKVStorage):
}
await self.db.execute(merge_sql, _data)
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -329,15 +330,16 @@ class OracleKVStorage(BaseKVStorage):
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
async def index_done_callback(self):
for n in ("full_docs", "text_chunks"):
if self.namespace.endswith(n):
logger.info("full doc and chunk data had been saved into oracle db!")
break
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into oracle db!")
@dataclass
@@ -614,13 +616,19 @@ class OracleGraphStorage(BaseGraphStorage):
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_GRAPH_NODES",
"relationships": "LIGHTRAG_GRAPH_EDGES",
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
}
def namespace_to_table_name(namespace: str) -> str:
for k, v in N_T.items():
if is_namespace(namespace, k):
return v
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (

View File

@@ -32,6 +32,7 @@ from ..base import (
BaseGraphStorage,
T,
)
from ..namespace import NameSpace, is_namespace
if sys.platform.startswith("win"):
import asyncio.windows_events
@@ -187,7 +188,7 @@ class PGKVStorage(BaseKVStorage):
"""Get doc_full data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
@@ -203,7 +204,7 @@ class PGKVStorage(BaseKVStorage):
"""Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, mode: mode, "id": id}
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
@@ -219,7 +220,7 @@ class PGKVStorage(BaseKVStorage):
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
@@ -239,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
return None
async def all_keys(self) -> list[dict]:
if self.namespace.endswith("llm_response_cache"):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
sql = "select workspace,mode,id from lightrag_llm_cache"
res = await self.db.query(sql, multirows=True)
return res
@@ -251,7 +252,7 @@ class PGKVStorage(BaseKVStorage):
async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=NAMESPACE_TABLE_MAP[self.namespace],
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.db.workspace}
@@ -270,9 +271,9 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: Dict[str, dict]):
if self.namespace.endswith("text_chunks"):
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass
elif self.namespace.endswith("full_docs"):
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
_data = {
@@ -281,7 +282,7 @@ class PGKVStorage(BaseKVStorage):
"workspace": self.db.workspace,
}
await self.db.execute(upsert_sql, _data)
elif self.namespace.endswith("llm_response_cache"):
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -296,12 +297,11 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self):
for n in ("full_docs", "text_chunks"):
if self.namespace.endswith(n):
logger.info(
"full doc and chunk data had been saved into postgresql db!"
)
break
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into postgresql db!")
@dataclass
@@ -393,11 +393,11 @@ class PGVectorStorage(BaseVectorStorage):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
for item in list_data:
if self.namespace.endswith("chunks"):
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
upsert_sql, data = self._upsert_chunks(item)
elif self.namespace.endswith("entities"):
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
upsert_sql, data = self._upsert_entities(item)
elif self.namespace.endswith("relationships"):
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
upsert_sql, data = self._upsert_relationships(item)
else:
raise ValueError(f"{self.namespace} is not supported")
@@ -1027,16 +1027,22 @@ class PGGraphStorage(BaseGraphStorage):
NAMESPACE_TABLE_MAP = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_VDB_ENTITY",
"relationships": "LIGHTRAG_VDB_RELATION",
"doc_status": "LIGHTRAG_DOC_STATUS",
"llm_response_cache": "LIGHTRAG_LLM_CACHE",
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
}
def namespace_to_table_name(namespace: str) -> str:
for k, v in NAMESPACE_TABLE_MAP.items():
if is_namespace(namespace, k):
return v
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (

View File

@@ -12,7 +12,9 @@ if not pm.is_installed("asyncpg"):
import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..namespace import NameSpace
DB = "rag"
USER = "rag"
@@ -76,7 +78,7 @@ db = PostgreSQLDB(
async def query_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace="chunk_entity_relation",
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)
@@ -92,7 +94,7 @@ async def query_with_age():
async def create_edge_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace="chunk_entity_relation",
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)

View File

@@ -14,8 +14,9 @@ if not pm.is_installed("sqlalchemy"):
from sqlalchemy import create_engine, text
from tqdm import tqdm
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger
from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from ..utils import logger
from ..namespace import NameSpace, is_namespace
class TiDB(object):
@@ -138,8 +139,8 @@ class TiDBKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace],
id_field=N_ID[self.namespace],
table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
try:
@@ -160,7 +161,7 @@ class TiDBKVStorage(BaseKVStorage):
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if self.namespace.endswith("text_chunks"):
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [
{
"__id__": k,
@@ -196,7 +197,7 @@ class TiDBKVStorage(BaseKVStorage):
)
await self.db.execute(merge_sql, data)
if self.namespace.endswith("full_docs"):
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
data = []
for k, v in self._data.items():
@@ -211,10 +212,11 @@ class TiDBKVStorage(BaseKVStorage):
return left_data
async def index_done_callback(self):
for n in ("full_docs", "text_chunks"):
if self.namespace.endswith(n):
logger.info("full doc and chunk data had been saved into TiDB db!")
break
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!")
@dataclass
@@ -260,7 +262,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
if self.namespace.endswith("chunks"):
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
return []
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
@@ -290,7 +292,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
for i, d in enumerate(list_data):
d["content_vector"] = embeddings[i]
if self.namespace.endswith("entities"):
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
data = []
for item in list_data:
param = {
@@ -311,7 +313,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
merge_sql = SQL_TEMPLATES["insert_entity"]
await self.db.execute(merge_sql, data)
elif self.namespace.endswith("relationships"):
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
data = []
for item in list_data:
param = {
@@ -470,20 +472,33 @@ class TiDBGraphStorage(BaseGraphStorage):
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_GRAPH_NODES",
"relationships": "LIGHTRAG_GRAPH_EDGES",
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
}
N_ID = {
"full_docs": "doc_id",
"text_chunks": "chunk_id",
"chunks": "chunk_id",
"entities": "entity_id",
"relationships": "relation_id",
NameSpace.KV_STORE_FULL_DOCS: "doc_id",
NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id",
NameSpace.VECTOR_STORE_CHUNKS: "chunk_id",
NameSpace.VECTOR_STORE_ENTITIES: "entity_id",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id",
}
def namespace_to_table_name(namespace: str) -> str:
for k, v in N_T.items():
if is_namespace(namespace, k):
return v
def namespace_to_id(namespace: str) -> str:
for k, v in N_ID.items():
if is_namespace(namespace, k):
return v
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """