Update Oracle support

Add LLM response cache support and fix bugs.
jin
2025-01-10 11:36:28 +08:00
parent 957bcf8659
commit 85331e3fa2
5 changed files with 284 additions and 48 deletions

View File

@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
CHATMODEL = "cohere.command-r-plus"
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
CHUNK_TOKEN_SIZE = 1024
MAX_TOKENS = 4000
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
@@ -86,27 +87,49 @@ async def main():
    # We use Oracle DB as the KV/vector/graph storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
        entity_extract_max_gleaning=1,
        embedding_cache_config=None,  # {"enabled": True, "similarity_threshold": 0.90},
        enable_llm_cache_for_entity_extract=True,
        chunk_token_size=CHUNK_TOKEN_SIZE,
        llm_model_max_token_size=MAX_TOKENS,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=500,
            func=embedding_func,
        ),
        graph_storage="OracleGraphStorage",
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
        doc_status_storage="OracleDocStatusStorage",
        addon_params={"example_number": 1, "language": "Simplified Chinese"},
    )
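The example above disables the embedding cache (`embedding_cache_config=None`), but the commented-out value shows its intended shape; a minimal sketch, reusing the threshold from that comment, to enable similarity-based cache hits:

    # Hedged sketch: serve cached answers when a new prompt embeds within 0.90 cosine similarity.
    embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},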
    # Set the KV/vector/graph storage's `db` property, so all operations use the same connection pool
    rag.doc_status_storage_cls.db = oracle_db
    rag.doc_status.db = oracle_db
    rag.full_docs.db = oracle_db
    rag.text_chunks.db = oracle_db
    rag.llm_response_cache.db = oracle_db
    rag.key_string_value_json_storage_cls.db = oracle_db
    rag.vector_db_storage_cls.db = oracle_db
    rag.chunks_vdb.db = oracle_db
    rag.relationships_vdb.db = oracle_db
    rag.entities_vdb.db = oracle_db
    rag.graph_storage_cls.db = oracle_db
    rag.chunk_entity_relation_graph.db = oracle_db
    # Re-attach embedding_func for the graph database; it was removed in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
    # Extract and insert into LightRAG storage
    with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
        await rag.ainsert(f.read())
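A natural next step (not part of this commit) is to query the index once ingestion finishes; a minimal sketch, assuming LightRAG's standard QueryParam API and the demo text above:

    from lightrag import QueryParam

    # Hedged sketch: run a hybrid (local + global) query against the data just inserted.
    print(await rag.aquery("How does Scrooge change?", param=QueryParam(mode="hybrid")))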

View File

@@ -3,7 +3,7 @@ import asyncio
# import html
# import os
from dataclasses import dataclass
from typing import Union, List, Dict, Set, Any, Tuple
import numpy as np
import array
@@ -12,6 +12,9 @@ from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocStatusStorage,
    DocStatus,
    DocProcessingStatus,
)
import oracledb
@@ -167,6 +170,9 @@ class OracleDB:
@dataclass
class OracleKVStorage(BaseKVStorage):
    # The caller should assign an OracleDB instance to self.db (see the example script)
    db: OracleDB = None
    meta_fields = None

    def __post_init__(self):
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -174,28 +180,56 @@ class OracleKVStorage(BaseKVStorage):
    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> Union[dict, None]:
        """Get doc_full data by id."""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "id": id}
        if self.namespace == "llm_response_cache":
            array_res = await self.db.query(SQL, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
        else:
            res = await self.db.query(SQL, params)
        if res:
            return res
        else:
            return None
    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
        """Specifically for llm_response_cache."""
        SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
        if self.namespace == "llm_response_cache":
            array_res = await self.db.query(SQL, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
            return res
        else:
            return None
    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
        """Get doc_chunks data by ids."""
        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
        res = await self.db.query(SQL, params, multirows=True)
        if self.namespace == "llm_response_cache":
            # Group cache rows first by mode, then by id
            modes = set()
            dict_res: dict[str, dict] = {}
            for row in res:
                modes.add(row["mode"])
            for mode in modes:
                if mode not in dict_res:
                    dict_res[mode] = {}
            for row in res:
                dict_res[row["mode"]][row["id"]] = row
            res = [{k: v} for k, v in dict_res.items()]
        if res:
            data = res
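To make the llm_response_cache grouping above concrete, here is a small illustration with hypothetical rows (not from this commit):

    rows = [
        {"id": "h1", "mode": "global", "return": "..."},
        {"id": "h2", "mode": "local", "return": "..."},
        {"id": "h3", "mode": "global", "return": "..."},
    ]
    # After grouping, res == [{"global": {"h1": rows[0], "h3": rows[2]}},
    #                         {"local": {"h2": rows[1]}}]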
@@ -204,7 +238,7 @@ class OracleKVStorage(BaseKVStorage):
            return None

    async def filter_keys(self, keys: list[str]) -> set[str]:
        """Filter out duplicated content."""
        SQL = SQL_TEMPLATES["filter_keys"].format(
            table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
        )
@@ -271,13 +305,26 @@ class OracleKVStorage(BaseKVStorage):
            merge_sql = SQL_TEMPLATES["merge_doc_full"]
            data = {
                "id": k,
                "content": v["content"],
                "workspace": self.db.workspace,
            }
            await self.db.execute(merge_sql, data)
        if self.namespace == "llm_response_cache":
            for mode, items in data.items():
                for k, v in items.items():
                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
                    _data = {
                        "workspace": self.db.workspace,
                        "id": k,
                        "original_prompt": v["original_prompt"],
                        "return_value": v["return"],
                        "cache_mode": mode,
                    }
                    await self.db.execute(upsert_sql, _data)

        return left_data
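For the cache branch, the incoming `data` dict is keyed first by cache mode and then by cache id; a hypothetical payload illustrating the expected shape:

    data = {
        "global": {"args_hash_1": {"original_prompt": "...", "return": "..."}},
        "local": {"args_hash_2": {"original_prompt": "...", "return": "..."}},
    }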
    async def index_done_callback(self):
@@ -285,8 +332,99 @@ class OracleKVStorage(BaseKVStorage):
        logger.info("Full doc and chunk data have been saved into Oracle db!")
@dataclass
class OracleDocStatusStorage(DocStatusStorage):
    """Oracle implementation of document status storage"""

    # The caller should assign an OracleDB instance to self.db
    db: OracleDB = None
    meta_fields = None

    def __post_init__(self):
        pass
    async def filter_keys(self, ids: list[str]) -> set[str]:
        """Return keys that don't exist in storage"""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
        res = await self.db.query(SQL, params, True)
        # The result looks like [{'id': 'id1'}, {'id': 'id2'}, ...]
        if res:
            existed = set([element["id"] for element in res])
            return set(ids) - existed
        else:
            return set(ids)
    async def get_status_counts(self) -> Dict[str, int]:
        """Get counts of documents in each status"""
        SQL = SQL_TEMPLATES["get_status_counts"]
        params = {"workspace": self.db.workspace}
        res = await self.db.query(SQL, params, True)
        # Result looks like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
        counts = {}
        for doc in res:
            counts[doc["status"]] = doc["count"]
        return counts
    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
        """Get all documents with the given status"""
        SQL = SQL_TEMPLATES["get_docs_by_status"]
        params = {"workspace": self.db.workspace, "status": status}
        res = await self.db.query(SQL, params, True)
        # Result looks like [{'id': 'id1', 'STATUS': 'PENDING', 'UPDATETIME': '2023-07-01 00:00:00'}, ...];
        # convert it to a dict keyed by document id
        return {
            element["id"]: DocProcessingStatus(
                # content_summary=element["content_summary"],
                content_summary="",
                content_length=element["CONTENT_LENGTH"],
                status=element["STATUS"],
                created_at=element["CREATETIME"],
                updated_at=element["UPDATETIME"],
                chunks_count=-1,
                # chunks_count=element["chunks_count"],
            )
            for element in res
        }
    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        return await self.get_docs_by_status(DocStatus.FAILED)

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        return await self.get_docs_by_status(DocStatus.PENDING)
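Together these helpers give callers a quick view of ingestion progress; a hedged usage sketch (the `storage` instance and its db wiring follow the example script, return values are hypothetical):

    counts = await storage.get_status_counts()  # e.g. {"PENDING": 1, "PROCESSED": 7}
    pending = await storage.get_pending_docs()  # {doc_id: DocProcessingStatus, ...}
    failed = await storage.get_failed_docs()    # {doc_id: DocProcessingStatus, ...}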
    async def index_done_callback(self):
        """For Oracle, data is already saved during upsert, so there is nothing to do here"""
        logger.info("Doc status has been saved into Oracle db!")
    async def upsert(self, data: dict[str, dict]):
        """Update or insert document status

        Args:
            data: Dictionary of document IDs and their status data
        """
        SQL = SQL_TEMPLATES["merge_doc_status"]
        for k, v in data.items():
            # chunks_count is optional
            params = {
                "workspace": self.db.workspace,
                "id": k,
                "content_summary": v["content_summary"],
                "content_length": v["content_length"],
                "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                "status": v["status"],
            }
            await self.db.execute(SQL, params)
        return data
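A hypothetical call showing the per-document payload upsert expects (chunks_count may be omitted and defaults to -1):

    await storage.upsert({
        "doc-abc123": {
            "content_summary": "opening chapters",
            "content_length": 48213,
            "status": DocStatus.PENDING,
        }
    })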
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
    # The caller should assign an OracleDB instance to self.db
    db: OracleDB = None
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
@@ -564,13 +702,18 @@ N_T = {
TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
            id varchar(256),
            workspace varchar(1024),
            doc_name varchar(1024),
            content CLOB,
            meta JSON,
            content_summary varchar(1024),
            content_length NUMBER,
            status varchar(256),
            chunks_count NUMBER,
            createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updatetime TIMESTAMP DEFAULT NULL,
            error varchar(4096)
            )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
@@ -618,10 +761,16 @@ TABLES = {
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
            id varchar(256) PRIMARY KEY,
            workspace varchar(1024),
            cache_mode varchar(256),
            model_name varchar(256),
            original_prompt clob,
            return_value clob,
            embedding CLOB,
            embedding_shape NUMBER,
            embedding_min NUMBER,
            embedding_max NUMBER,
            createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updatetime TIMESTAMP DEFAULT NULL
            )"""
@@ -647,22 +796,70 @@ TABLES = {
SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
    "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
    "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)
""",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
ON (a.id = :id)
WHEN NOT MATCHED THEN
INSERT (workspace,id,original_prompt,return_value,cache_mode)
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
WHEN MATCHED THEN UPDATE
SET original_prompt = :original_prompt,
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
"get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})",
"get_status_counts": """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""",
"get_docs_by_status": """select content_length,status,
TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime
from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""",
"merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status)
WHEN MATCHED THEN UPDATE
SET content_summary = :content_summary,
content_length = :content_length,
chunks_count = :chunks_count,
status = :status,
updatetime = SYSDATE""",
    # SQL for VectorStorage
    "entities": """SELECT name as entity_name FROM
        (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
@@ -714,16 +911,22 @@ SQL_TEMPLATES = {
        COLUMNS (a.name as source_name,b.name as target_name))""",
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
USING DUAL
ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
ON (a.workspace=:workspace and a.name=:name)
WHEN NOT MATCHED THEN
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
USING DUAL
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
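All of these MERGE templates are executed with named binds. A minimal, hedged sketch of how the OracleDB wrapper might run one, assuming python-oracledb's asyncio API (pool creation and error handling elided):

    import oracledb

    async def execute(pool: oracledb.AsyncConnectionPool, sql: str, data: dict):
        # Acquire a pooled connection, bind the :name placeholders from `data`, commit.
        async with pool.acquire() as connection:
            cursor = connection.cursor()
            await cursor.execute(sql, data)
            await connection.commit()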
"get_all_nodes": """WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids

View File

@@ -79,6 +79,7 @@ Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
@@ -290,6 +291,7 @@ class LightRAG:
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            "OracleDocStatusStorage": OracleDocStatusStorage,
            "MongoKVStorage": MongoKVStorage,
            "TiDBKVStorage": TiDBKVStorage,
            # vector storage

View File

@@ -59,13 +59,15 @@ async def _handle_entity_relation_summary(
    description: str,
    global_config: dict,
) -> str:
    """Summarize an entity's or relation's description.

    For each entity or relation, the input is the already-existing description combined with the new description;
    if it is too long, use the LLM to summarize it.
    """
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:  # No need for summary
@@ -139,6 +141,7 @@ async def _merge_nodes_then_upsert(
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    """Look up the node by name in the knowledge graph; if it exists, merge in the new data, otherwise create it, then upsert the result."""
    already_entity_types = []
    already_source_ids = []
    already_description = []
@@ -319,7 +322,7 @@ async def extract_entities(
            llm_response_cache.global_config = new_config
            need_to_restore = True
        if history_messages:
            history = json.dumps(history_messages, ensure_ascii=False)
            _prompt = history + "\n" + input_text
        else:
            _prompt = input_text
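The `ensure_ascii=False` change matters for non-English content: with the default, CJK history is escaped to \uXXXX sequences before being hashed and cached. A tiny stdlib illustration:

    import json

    history = [{"role": "user", "content": "你好"}]
    print(json.dumps(history))                      # [{"role": "user", "content": "\u4f60\u597d"}]
    print(json.dumps(history, ensure_ascii=False))  # [{"role": "user", "content": "你好"}]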
@@ -351,6 +354,11 @@ async def extract_entities(
        return await use_llm_func(input_text)
    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Process a single chunk.

        Args:
            chunk_key_dp (tuple[str, TextChunkSchema]):
                ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
        """
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]

View File

@@ -36,7 +36,7 @@ logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
@@ -473,7 +473,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use the embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"].func
        llm_model_func = hashing_kv.global_config.get("llm_model_func")

        current_embedding = await embedding_model_func([prompt])
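The embedding, embedding_shape, embedding_min, and embedding_max columns in LIGHTRAG_LLM_CACHE exist because cached prompt embeddings are stored quantized. A minimal sketch of the round trip, assuming simple linear 8-bit quantization (LightRAG ships its own helpers for this; their exact signatures are not shown in this diff):

    import numpy as np

    def quantize_embedding(embedding: np.ndarray, bits: int = 8):
        # Map floats linearly onto [0, 2**bits - 1]; keep min/max so the mapping can be inverted.
        min_val, max_val = float(embedding.min()), float(embedding.max())
        scale = (2**bits - 1) / (max_val - min_val)  # assumes max_val > min_val
        quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
        return quantized, min_val, max_val

    def dequantize_embedding(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8):
        # Invert the linear mapping; the small precision loss is fine for similarity checks.
        scale = (max_val - min_val) / (2**bits - 1)
        return quantized.astype(np.float32) * scale + min_val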