support pipeline mode

jin
2025-01-16 12:52:37 +08:00
parent 17a2ec2bc4
commit d5ae6669ea
5 changed files with 374 additions and 323 deletions


@@ -12,9 +12,6 @@ from ..base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocStatusStorage,
DocStatus,
DocProcessingStatus,
)
import oracledb
@@ -156,8 +153,6 @@ class OracleDB:
if data is None:
await cursor.execute(sql)
else:
# print(data)
# print(sql)
await cursor.execute(sql, data)
await connection.commit()
except Exception as e:
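For reference, the execute path above follows python-oracledb's async bind-variable pattern. A minimal sketch of that shape (the pool name and wiring here are illustrative, not from this commit):

```python
# Sketch: parameterized execute-and-commit against an async oracledb pool.
# `pool` is a hypothetical pool created with oracledb.create_pool_async().
async def execute(pool, sql: str, data: dict | None = None):
    async with pool.acquire() as connection:
        cursor = connection.cursor()
        if data is None:
            await cursor.execute(sql)
        else:
            # Named binds (:id, :workspace, ...) are matched by dict key.
            await cursor.execute(sql, data)
        await connection.commit()
```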
@@ -175,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
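The switch to `dict.get` with a default is defensive: a missing `embedding_batch_num` key now falls back to 10 instead of raising `KeyError`; the same change appears in `OracleGraphStorage.__post_init__` further down.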
################ QUERY METHODS ################
@@ -204,12 +199,11 @@ class OracleKVStorage(BaseKVStorage):
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
@@ -228,8 +222,7 @@ class OracleKVStorage(BaseKVStorage):
dict_res[mode] = {}
for row in res:
dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()]
if res:
data = res # [{"data":i} for i in res]
# print(data)
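A note on the `{ids}` interpolation these templates rely on: quoting each id with an f-string and splicing the list into the SQL works, but it bypasses bind variables. A bind-per-id alternative would look roughly like this (a sketch of an alternative, not what this commit does):

```python
# Hypothetical helper: one bind variable per id instead of string splicing.
def in_clause(ids: list[str]) -> tuple[str, dict]:
    placeholders = ",".join(f":id{i}" for i in range(len(ids)))
    binds = {f"id{i}": value for i, value in enumerate(ids)}
    return placeholders, binds

# usage sketch (Oracle caps a single IN list at 1000 items):
# placeholders, binds = in_clause(keys)
# sql = f"select id from {table_name} where workspace=:workspace and id in ({placeholders})"
# params = {"workspace": workspace, **binds}
```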
@@ -237,38 +230,42 @@ class OracleKVStorage(BaseKVStorage):
else:
return None
async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
else:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, multirows=True)
if res:
return res
else:
return None
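`get_by_status_and_ids` is the query half of the new pipeline mode; paired with `change_status` (added further down), a driver can walk documents through their states. A hypothetical usage sketch; the status names are assumptions based on the DocStatus values visible in the removed code below:

```python
# Sketch of a pipeline step: claim pending rows, process, record outcome.
# `kv` is an OracleKVStorage instance; `process` is a hypothetical worker.
docs = await kv.get_by_status_and_ids(status="PENDING", ids=None)
for doc in docs or []:
    await kv.change_status(doc["id"], "PROCESSING")
    try:
        await process(doc)
        await kv.change_status(doc["id"], "PROCESSED")
    except Exception:
        await kv.change_status(doc["id"], "FAILED")
```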
async def filter_keys(self, keys: list[str]) -> set[str]:
"""remove duplicated"""
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
)
params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
res = await self.db.query(SQL, params, multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
return data
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
return set(keys)
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
# print(self._data)
# values = []
if self.namespace == "text_chunks":
list_data = [
{
"__id__": k,
"id": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
@@ -284,33 +281,30 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# print(list_data)
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {
"check_id": item["__id__"],
"id": item["__id__"],
_data = {
"id": item["id"],
"content": item["content"],
"workspace": self.db.workspace,
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": item["__vector__"],
"status": item["status"],
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, _data)
if self.namespace == "full_docs":
for k, v in self._data.items():
for k, v in data.items():
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"]
data = {
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, _data)
if self.namespace == "llm_response_cache":
for mode, items in data.items():
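The text_chunks branch above embeds chunk contents in batches of `_max_batch_size` before looping over the MERGE. The gist of that batching step, as a sketch with `contents`, `embedding_func`, and `list_data` assumed from the surrounding method:

```python
import numpy as np

# Sketch: split contents into batches, embed each batch, then
# concatenate into a single (n_chunks, dim) array of vectors.
batches = [
    contents[i : i + max_batch_size]
    for i in range(0, len(contents), max_batch_size)
]
embeddings_list = [await embedding_func(batch) for batch in batches]
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
    d["__vector__"] = embeddings[i]
```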
@@ -325,102 +319,20 @@ class OracleKVStorage(BaseKVStorage):
}
await self.db.execute(upsert_sql, _data)
return left_data
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(
table_name=N_T[self.namespace]
)
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
logger.info("full doc and chunk data had been saved into oracle db!")
@dataclass
class OracleDocStatusStorage(DocStatusStorage):
"""Oracle implementation of document status storage"""
# should pass db object to self.db
db: OracleDB = None
meta_fields = None
def __post_init__(self):
pass
async def filter_keys(self, ids: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
ids = ",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if res:
existed = set([element["id"] for element in res])
return set(ids) - existed
else:
return set(ids)
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
SQL = SQL_TEMPLATES["get_status_counts"]
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
counts = {}
for doc in res:
counts[doc["status"]] = doc["count"]
return counts
async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status"""
SQL = SQL_TEMPLATES["get_docs_by_status"]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, True)
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
# Converting to be a dict
return {
element["id"]: DocProcessingStatus(
#content_summary=element["content_summary"],
content_summary = "",
content_length=element["CONTENT_LENGTH"],
status=element["STATUS"],
created_at=element["CREATETIME"],
updated_at=element["UPDATETIME"],
chunks_count=-1,
#chunks_count=element["chunks_count"],
)
for element in res
}
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def index_done_callback(self):
"""Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into ORACLE db!")
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
SQL = SQL_TEMPLATES["merge_doc_status"]
for k, v in data.items():
# chunks_count is optional
params = {
"workspace": self.db.workspace,
"id": k,
"content_summary": v["content_summary"],
"content_length": v["content_length"],
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
}
await self.db.execute(SQL, params)
return data
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
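A detail worth keeping in mind from the removed class above: keys like `element["CONTENT_LENGTH"]` are upper-case because Oracle folds unquoted identifiers to upper case, so result-dict keys arrive upper-cased unless the SELECT quotes an alias. For example:

```python
# Quoted aliases keep their case; unquoted columns come back upper-cased.
sql = 'SELECT status as "status", COUNT(1) as "count" FROM LIGHTRAG_DOC_FULL'
# -> row keys "status" and "count"
sql = "SELECT content_length FROM LIGHTRAG_DOC_FULL"
# -> row key "CONTENT_LENGTH"
```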
@@ -466,7 +378,7 @@ class OracleGraphStorage(BaseGraphStorage):
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -500,7 +412,6 @@ class OracleGraphStorage(BaseGraphStorage):
"content": content,
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data)
@@ -718,9 +629,10 @@ TABLES = {
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id varchar(256) PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
full_doc_id varchar(256),
status varchar(256),
chunk_order_index NUMBER,
tokens NUMBER,
content CLOB,
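With the `status` column added and the single-column PRIMARY KEY dropped, uniqueness of chunks now rests entirely on the `merge_chunk` template matching on `(id, workspace)`. A composite constraint would make that logical key explicit at the schema level (a suggestion, not part of this commit):

```python
# Hypothetical DDL, not in this commit: enforce the (workspace, id) pair
# that merge_chunk already treats as the row identity.
ADD_UNIQUE_CHUNK_KEY = """ALTER TABLE LIGHTRAG_DOC_CHUNKS
    ADD CONSTRAINT uq_lightrag_doc_chunks UNIQUE (workspace, id)"""
```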
@@ -795,9 +707,9 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
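One Oracle quirk behind the `NVL(content,'')` wrappers being dropped from some templates here: Oracle treats a zero-length string as NULL, so `NVL(x, '')` still yields NULL when `x` is NULL; selecting the raw column and handling None on the Python side is equivalent and simpler.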
@@ -808,24 +720,34 @@ SQL_TEMPLATES = {
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN INSERT
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
@@ -838,27 +760,7 @@ SQL_TEMPLATES = {
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
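The `MERGE ... USING DUAL` statements throughout this file are Oracle's idiom for upsert: DUAL supplies a single driving row, the ON clause probes for an existing key, and the WHEN branches either update or insert. A minimal generic sketch (table and column names are placeholders):

```python
# Generic Oracle upsert shape. Note ORA-38104: columns referenced in the
# ON clause (here id and workspace) cannot be assigned in WHEN MATCHED.
UPSERT_SKETCH = """MERGE INTO my_table t
    USING DUAL
    ON (t.id = :id AND t.workspace = :workspace)
    WHEN MATCHED THEN UPDATE SET
        t.content = :content,
        t.updatetime = SYSDATE
    WHEN NOT MATCHED THEN INSERT (id, workspace, content, updatetime)
        VALUES (:id, :workspace, :content, SYSDATE)"""
```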
"get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})",
"get_status_counts": """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""",
"get_docs_by_status": """select content_length,status,
TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime
from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""",
"merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status)
WHEN MATCHED THEN UPDATE
SET content_summary = :content_summary,
content_length = :content_length,
chunks_count = :chunks_count,
status = :status,
updatetime = SYSDATE""",
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM