better handling of namespace
This commit is contained in:
@@ -19,6 +19,7 @@ from ..base import (
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
)
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
|
||||
import oracledb
|
||||
|
||||
@@ -185,7 +186,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
if self.namespace.endswith("llm_response_cache"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
array_res = await self.db.query(SQL, params, multirows=True)
|
||||
res = {}
|
||||
for row in array_res:
|
||||
@@ -201,7 +202,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
"""Specifically for llm_response_cache."""
|
||||
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
|
||||
if self.namespace.endswith("llm_response_cache"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
array_res = await self.db.query(SQL, params, multirows=True)
|
||||
res = {}
|
||||
for row in array_res:
|
||||
@@ -218,7 +219,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
params = {"workspace": self.db.workspace}
|
||||
# print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
if self.namespace.endswith("llm_response_cache"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
modes = set()
|
||||
dict_res: dict[str, dict] = {}
|
||||
for row in res:
|
||||
@@ -256,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
|
||||
table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
|
||||
)
|
||||
params = {"workspace": self.db.workspace}
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
@@ -269,7 +270,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
################ INSERT METHODS ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
if self.namespace.endswith("text_chunks"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
list_data = [
|
||||
{
|
||||
"id": k,
|
||||
@@ -302,7 +303,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
"status": item["status"],
|
||||
}
|
||||
await self.db.execute(merge_sql, _data)
|
||||
if self.namespace.endswith("full_docs"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||
for k, v in data.items():
|
||||
# values.clear()
|
||||
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
||||
@@ -313,7 +314,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
}
|
||||
await self.db.execute(merge_sql, _data)
|
||||
|
||||
if self.namespace.endswith("llm_response_cache"):
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
for mode, items in data.items():
|
||||
for k, v in items.items():
|
||||
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
||||
@@ -329,15 +330,16 @@ class OracleKVStorage(BaseKVStorage):
|
||||
return None
|
||||
|
||||
async def change_status(self, id: str, status: str):
|
||||
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
|
||||
SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
|
||||
params = {"workspace": self.db.workspace, "id": id, "status": status}
|
||||
await self.db.execute(SQL, params)
|
||||
|
||||
async def index_done_callback(self):
|
||||
for n in ("full_docs", "text_chunks"):
|
||||
if self.namespace.endswith(n):
|
||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||
break
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -614,13 +616,19 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
||||
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
||||
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
|
||||
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
|
||||
}
|
||||
|
||||
def namespace_to_table_name(namespace: str) -> str:
|
||||
for k, v in N_T.items():
|
||||
if is_namespace(namespace, k):
|
||||
return v
|
||||
|
||||
|
||||
TABLES = {
|
||||
"LIGHTRAG_DOC_FULL": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||
|
Reference in New Issue
Block a user