Some enhancements:
- Enable the llm_cache storage to support get_by_mode_and_id, which improves performance when a real KV server backs the cache.
- Provide an option for developers to cache LLM responses while extracting entities from a document. This addresses the pain point that when the process fails partway through, the already-processed chunks must be sent to the LLM again, wasting time and money. With the new option enabled (it is disabled by default), those intermediate results are cached, which can save beginners significant time and money (see the usage sketch below).
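
A minimal usage sketch for the new flag. Only enable_llm_cache and enable_llm_cache_for_entity_extract come from this change; the other constructor arguments, the stub model function, and the file name are illustrative placeholders, not values prescribed by this commit.

    from lightrag import LightRAG

    # Placeholder stand-in for a real model function; any async callable that
    # returns the model's text completion would be used here in practice.
    async def my_llm_model_func(prompt, **kwargs) -> str:
        ...

    rag = LightRAG(
        working_dir="./rag_storage",               # assumed existing option
        llm_model_func=my_llm_model_func,          # assumed existing option
        enable_llm_cache=True,                     # existing flag: cache query-time LLM responses
        enable_llm_cache_for_entity_extract=True,  # new flag: also cache entity-extraction calls
    )

    # If indexing fails partway and is re-run, chunks whose extraction responses
    # were already cached are served from llm_response_cache instead of new LLM calls.
    rag.insert(open("my_document.txt").read())
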
@@ -151,7 +151,10 @@ class PostgreSQLDB:
             try:
                 await conn.execute('SET search_path = ag_catalog, "$user", public')
                 await conn.execute(f"""select create_graph('{graph_name}')""")
-            except asyncpg.exceptions.InvalidSchemaNameError:
+            except (
+                asyncpg.exceptions.InvalidSchemaNameError,
+                asyncpg.exceptions.UniqueViolationError,
+            ):
                 pass

@@ -160,7 +163,6 @@ class PGKVStorage(BaseKVStorage):
     db: PostgreSQLDB = None

     def __post_init__(self):
-        self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]

     ################ QUERY METHODS ################
@@ -181,6 +183,19 @@ class PGKVStorage(BaseKVStorage):
         else:
             return None

+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        """Specifically for llm_response_cache."""
+        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, mode: mode, "id": id}
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(sql, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+            return res
+        else:
+            return None
+
     # Query by id
     async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
         """Get doc_chunks data by id"""
@@ -229,33 +244,30 @@ class PGKVStorage(BaseKVStorage):

     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
         if self.namespace == "text_chunks":
             pass
         elif self.namespace == "full_docs":
-            for k, v in self._data.items():
+            for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
-                data = {
+                _data = {
                     "id": k,
                     "content": v["content"],
                     "workspace": self.db.workspace,
                 }
-                await self.db.execute(upsert_sql, data)
+                await self.db.execute(upsert_sql, _data)
         elif self.namespace == "llm_response_cache":
-            for mode, items in self._data.items():
+            for mode, items in data.items():
                 for k, v in items.items():
                     upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
-                    data = {
+                    _data = {
                         "workspace": self.db.workspace,
                         "id": k,
                         "original_prompt": v["original_prompt"],
-                        "return": v["return"],
+                        "return_value": v["return"],
                         "mode": mode,
                     }
-                    await self.db.execute(upsert_sql, data)
-
-        return left_data
+                    await self.db.execute(upsert_sql, _data)

     async def index_done_callback(self):
         if self.namespace in ["full_docs", "text_chunks"]:
@@ -977,9 +989,6 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(
-            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
-        )

         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1028,8 +1037,8 @@ TABLES = {
                     doc_name VARCHAR(1024),
                     content TEXT,
                     meta JSONB,
-                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP,
+                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    update_time TIMESTAMP,
                     CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                     )"""
     },
@@ -1042,8 +1051,8 @@ TABLES = {
                     tokens INTEGER,
                     content TEXT,
                     content_vector VECTOR,
-                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP,
+                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    update_time TIMESTAMP,
                     CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                     )"""
     },
@@ -1054,8 +1063,8 @@ TABLES = {
                     entity_name VARCHAR(255),
                     content TEXT,
                     content_vector VECTOR,
-                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP,
+                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    update_time TIMESTAMP,
                     CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                     )"""
     },
@@ -1067,8 +1076,8 @@ TABLES = {
                     target_id VARCHAR(256),
                     content TEXT,
                     content_vector VECTOR,
-                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP,
+                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    update_time TIMESTAMP,
                     CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                     )"""
     },
@@ -1078,10 +1087,10 @@ TABLES = {
                     id varchar(255) NOT NULL,
                     mode varchar(32) NOT NULL,
                     original_prompt TEXT,
-                    return TEXT,
-                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP,
-                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
+                    return_value TEXT,
+                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    update_time TIMESTAMP,
+                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
                     )"""
     },
     "LIGHTRAG_DOC_STATUS": {
@@ -1109,9 +1118,12 @@ SQL_TEMPLATES = {
                                   chunk_order_index, full_doc_id
                                   FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                                """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
                             """,
+    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
+                              FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
+                           """,
     "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
                                 FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
                              """,
@@ -1119,22 +1131,22 @@ SQL_TEMPLATES = {
                                   chunk_order_index, full_doc_id
                                   FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                                """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                  FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
                             """,
     "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
     "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
                           VALUES ($1, $2, $3)
                           ON CONFLICT (workspace,id) DO UPDATE
-                          SET content = $2, updatetime = CURRENT_TIMESTAMP
+                          SET content = $2, update_time = CURRENT_TIMESTAMP
                        """,
-    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode)
+    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
                                       VALUES ($1, $2, $3, $4, $5)
-                                      ON CONFLICT (workspace,id) DO UPDATE
+                                      ON CONFLICT (workspace,mode,id) DO UPDATE
                                       SET original_prompt = EXCLUDED.original_prompt,
-                                      "return"=EXCLUDED."return",
+                                      return_value=EXCLUDED.return_value,
                                       mode=EXCLUDED.mode,
-                                      updatetime = CURRENT_TIMESTAMP
+                                      update_time = CURRENT_TIMESTAMP
                                    """,
     "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
                       chunk_order_index, full_doc_id, content, content_vector)
@@ -1145,7 +1157,7 @@ SQL_TEMPLATES = {
                       full_doc_id=EXCLUDED.full_doc_id,
                       content = EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
-                      updatetime = CURRENT_TIMESTAMP
+                      update_time = CURRENT_TIMESTAMP
                    """,
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
                         VALUES ($1, $2, $3, $4, $5)
@@ -1153,7 +1165,7 @@ SQL_TEMPLATES = {
                         SET entity_name=EXCLUDED.entity_name,
                             content=EXCLUDED.content,
                             content_vector=EXCLUDED.content_vector,
-                            updatetime=CURRENT_TIMESTAMP
+                            update_time=CURRENT_TIMESTAMP
                      """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
                               target_id, content, content_vector)
@@ -1162,7 +1174,7 @@ SQL_TEMPLATES = {
                               SET source_id=EXCLUDED.source_id,
                                   target_id=EXCLUDED.target_id,
                                   content=EXCLUDED.content,
-                                  content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP
+                                  content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
                            """,
     # SQL for VectorStorage
     "entities": """SELECT entity_name FROM

@@ -176,6 +176,8 @@ class LightRAG:
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

     enable_llm_cache: bool = True
+    # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
+    enable_llm_cache_for_entity_extract: bool = False

     # extension
     addon_params: dict = field(default_factory=dict)
@@ -402,6 +404,7 @@ class LightRAG:
             knowledge_graph_inst=self.chunk_entity_relation_graph,
             entity_vdb=self.entities_vdb,
             relationships_vdb=self.relationships_vdb,
+            llm_response_cache=self.llm_response_cache,
             global_config=asdict(self),
         )

@@ -253,9 +253,13 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict,
+    llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
+    enable_llm_cache_for_entity_extract: bool = global_config[
+        "enable_llm_cache_for_entity_extract"
+    ]

     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
@@ -300,6 +304,52 @@ async def extract_entities(
     already_entities = 0
     already_relations = 0

+    async def _user_llm_func_with_cache(
+        input_text: str, history_messages: list[dict[str, str]] = None
+    ) -> str:
+        if enable_llm_cache_for_entity_extract and llm_response_cache:
+            need_to_restore = False
+            if (
+                global_config["embedding_cache_config"]
+                and global_config["embedding_cache_config"]["enabled"]
+            ):
+                new_config = global_config.copy()
+                new_config["embedding_cache_config"] = None
+                new_config["enable_llm_cache"] = True
+                llm_response_cache.global_config = new_config
+                need_to_restore = True
+            if history_messages:
+                history = json.dumps(history_messages)
+                _prompt = history + "\n" + input_text
+            else:
+                _prompt = input_text
+
+            arg_hash = compute_args_hash(_prompt)
+            cached_return, _1, _2, _3 = await handle_cache(
+                llm_response_cache, arg_hash, _prompt, "default"
+            )
+            if need_to_restore:
+                llm_response_cache.global_config = global_config
+            if cached_return:
+                return cached_return
+
+            if history_messages:
+                res: str = await use_llm_func(
+                    input_text, history_messages=history_messages
+                )
+            else:
+                res: str = await use_llm_func(input_text)
+            await save_to_cache(
+                llm_response_cache,
+                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+            )
+            return res
+
+        if history_messages:
+            return await use_llm_func(input_text, history_messages=history_messages)
+        else:
+            return await use_llm_func(input_text)
+
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
         chunk_key = chunk_key_dp[0]
@@ -310,17 +360,19 @@ async def extract_entities(
             **context_base, input_text="{input_text}"
         ).format(**context_base, input_text=content)

-        final_result = await use_llm_func(hint_prompt)
+        final_result = await _user_llm_func_with_cache(hint_prompt)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
         for now_glean_index in range(entity_extract_max_gleaning):
-            glean_result = await use_llm_func(continue_prompt, history_messages=history)
+            glean_result = await _user_llm_func_with_cache(
+                continue_prompt, history_messages=history
+            )

             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
             final_result += glean_result
             if now_glean_index == entity_extract_max_gleaning - 1:
                 break

-            if_loop_result: str = await use_llm_func(
+            if_loop_result: str = await _user_llm_func_with_cache(
                 if_loop_prompt, history_messages=history
             )
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
@@ -454,7 +454,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):

     # For naive mode, only use simple cache matching
     if mode == "naive":
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if exists_func(hashing_kv, "get_by_mode_and_id"):
+            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+        else:
+            mode_cache = await hashing_kv.get_by_id(mode) or {}
         if args_hash in mode_cache:
             return mode_cache[args_hash]["return"], None, None, None
         return None, None, None, None
@@ -488,7 +491,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
             return best_cached_response, None, None, None
     else:
         # Use regular cache
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if exists_func(hashing_kv, "get_by_mode_and_id"):
+            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+        else:
+            mode_cache = await hashing_kv.get_by_id(mode) or {}
         if args_hash in mode_cache:
             return mode_cache[args_hash]["return"], None, None, None

@@ -510,7 +516,13 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return

-    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = (
+            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
+            or {}
+        )
+    else:
+        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
@@ -543,3 +555,15 @@ def safe_unicode_decode(content):
     )

     return decoded_content
+
+
+def exists_func(obj, func_name: str) -> bool:
+    """Check if a function exists in an object or not.
+    :param obj:
+    :param func_name:
+    :return: True / False
+    """
+    if callable(getattr(obj, func_name, None)):
+        return True
+    else:
+        return False