support pipeline mode

Author: jin
Date: 2025-01-16 12:58:15 +08:00
parent d5ae6669ea
commit 6ae8647285
6 changed files with 203 additions and 172 deletions

.gitignore

@@ -21,4 +21,4 @@ rag_storage
venv/
examples/input/
examples/output/
.DS_Store

examples/lightrag_oracle_demo.py

@@ -89,49 +89,45 @@ async def main():
rag = LightRAG(
# log_level="DEBUG",
working_dir=WORKING_DIR,
entity_extract_max_gleaning = 1,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract = True,
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
enable_llm_cache_for_entity_extract=True,
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
chunk_token_size=CHUNK_TOKEN_SIZE,
llm_model_max_token_size = MAX_TOKENS,
llm_model_max_token_size=MAX_TOKENS,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=500,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage = "OracleKVStorage",
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params = {"example_number":1,
"language":"Simplfied Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size":2,
}
addon_params={
"example_number": 1,
"language": "Simplfied Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size": 2,
},
)
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
rag.set_storage_client(db_client = oracle_db)
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
rag.set_storage_client(db_client=oracle_db)
# Extract and Insert into LightRAG storage
with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f:
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
all_text = f.read()
texts = [x for x in all_text.split("\n") if x]
# New mode use pipeline
await rag.apipeline_process_documents(texts)
await rag.apipeline_process_chunks()
await rag.apipeline_process_extract_graph()
# Old method use ainsert
#await rag.ainsert(texts)
# await rag.ainsert(texts)
# Perform search in different modes
modes = ["naive", "local", "global", "hybrid"]
for mode in modes:

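For reference, a minimal sketch of the pipeline flow the demo above switches to, in place of the single ainsert() call. The helper name index_with_pipeline is made up here; the Oracle connection and the LightRAG configuration are assumed to be set up as in the demo.

import asyncio

async def index_with_pipeline(rag, texts):
    # Stage 1: store the raw documents and mark them as pending
    await rag.apipeline_process_documents(texts)
    # Stage 2: split pending/failed documents into chunks and store the chunks
    await rag.apipeline_process_chunks()
    # Stage 3: extract entities and relationships from pending/failed chunks
    await rag.apipeline_process_extract_graph()

# Old single-call path: await rag.ainsert(texts)
# asyncio.run(index_with_pipeline(rag, texts))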
lightrag/kg/oracle_impl.py

@@ -3,7 +3,7 @@ import asyncio
# import html
# import os
from dataclasses import dataclass
from typing import Union, List, Dict, Set, Any, Tuple
from typing import Union
import numpy as np
import array
@@ -170,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config.get("embedding_batch_num",10)
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
################ QUERY METHODS ################
@@ -190,7 +190,7 @@ class OracleKVStorage(BaseKVStorage):
return res
else:
return None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
@@ -199,11 +199,11 @@ class OracleKVStorage(BaseKVStorage):
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
@@ -222,7 +222,7 @@ class OracleKVStorage(BaseKVStorage):
dict_res[mode] = {}
for row in res:
dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()]
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -230,7 +230,9 @@ class OracleKVStorage(BaseKVStorage):
else:
return None
async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
async def get_by_status_and_ids(
self, status: str, ids: list[str]
) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
@@ -244,7 +246,7 @@ class OracleKVStorage(BaseKVStorage):
return res
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
@@ -258,7 +260,6 @@ class OracleKVStorage(BaseKVStorage):
return data
else:
return set(keys)
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
@@ -281,7 +282,7 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data:
_data = {
@@ -320,11 +321,9 @@ class OracleKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(
table_name=N_T[self.namespace]
)
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
@@ -673,8 +672,8 @@ TABLES = {
},
"LIGHTRAG_LLM_CACHE": {
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
id varchar(256) PRIMARY KEY,
workspace varchar(1024),
cache_mode varchar(256),
model_name varchar(256),
original_prompt clob,
@@ -708,47 +707,32 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL
ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN INSERT
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
ON (a.id = :id)
@@ -760,8 +744,6 @@ SQL_TEMPLATES = {
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
@@ -818,7 +800,7 @@ SQL_TEMPLATES = {
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
USING DUAL
@@ -827,7 +809,7 @@ SQL_TEMPLATES = {
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"get_all_nodes": """WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,

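Illustrative sketch of how the status-aware KV methods added above (get_by_status_and_ids, change_status) are meant to be driven by the pipeline. DocStatus and the storage instance are assumed to come from LightRAG's base/storage setup; the import path is an assumption.

from lightrag.base import DocStatus  # assumed import path for the status enum

async def pending_and_failed_ids(kv_storage):
    """Collect ids of rows that still need processing (FAILED first, then PENDING)."""
    todo = []
    for status in (DocStatus.FAILED, DocStatus.PENDING):
        rows = await kv_storage.get_by_status_and_ids(status=status, ids=None)
        if rows:
            todo.extend(row["id"] for row in rows)
    return todo

# Once an item has been handled, its status column is updated in place:
#   await kv_storage.change_status(item_id, DocStatus.PROCESSED)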
lightrag/lightrag.py

@@ -26,7 +26,7 @@ from .utils import (
convert_response_to_json,
logger,
set_logger,
statistic_data
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -39,30 +39,30 @@ from .base import (
from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"JsonKVStorage": ".storage",
"NanoVectorDBStorage": ".storage",
"NetworkXStorage": ".storage",
"JsonDocStatusStorage": ".storage",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
}
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
@@ -75,6 +75,7 @@ def lazy_external_import(module_name: str, class_name: str):
def import_class(*args, **kwargs):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
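A standalone sketch of the lazy-import pattern used by lazy_external_import above, using an absolute standard-library module so it runs on its own; the real helper resolves relative paths such as ".kg.oracle_impl" against the caller's package.

import importlib

def lazy_import(module_name: str, class_name: str):
    """Return a factory that imports the module only when the class is first instantiated."""
    def import_class(*args, **kwargs):
        module = importlib.import_module(module_name)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)
    return import_class

OrderedDictFactory = lazy_import("collections", "OrderedDict")
od = OrderedDictFactory(a=1, b=2)  # "collections" is imported only at this call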
@@ -190,7 +191,7 @@ class LightRAG:
os.makedirs(self.working_dir)
# show config
global_config=asdict(self)
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -198,31 +199,33 @@ class LightRAG:
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage)
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls,
global_config=global_config
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls,
global_config=global_config
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls,
global_config=global_config
self.graph_storage_cls, global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None,
)
self.llm_response_cache = self.key_string_value_json_storage_cls(
@@ -264,13 +267,15 @@ class LightRAG:
embedding_func=self.embedding_func,
)
if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"):
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
@@ -292,21 +297,24 @@ class LightRAG:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self,db_client):
def set_storage_client(self, db_client):
# Now only tested on Oracle Database
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
@@ -348,11 +356,6 @@ class LightRAG:
}
for content in unique_contents
}
# 3. Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
# 3. Filter out already processed documents
_add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
@@ -401,7 +404,12 @@ class LightRAG:
}
# Update status with chunks information
doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()})
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
try:
@@ -425,16 +433,30 @@ class LightRAG:
self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)
# Update status to processed
doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()})
doc_status.update(
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
except Exception as e:
# Mark as failed if any step fails
doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()})
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
raise e
@@ -527,7 +549,9 @@ class LightRAG:
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents")
logger.info(
f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
)
# 2. Generate document IDs and initial status
new_docs = {
@@ -542,28 +566,34 @@ class LightRAG:
for content in unique_contents
}
# 3. Filter out already processed documents
# 3. Filter out already processed documents
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents")
logger.info(
f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
)
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
if not new_docs:
logger.info(f"All documents have been processed or are duplicates")
logger.info("All documents have been processed or are duplicates")
return None
# 4. Store original document
for doc_id, doc in new_docs.items():
await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
await self.full_docs.change_status(doc_id, DocStatus.PENDING)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
_failed_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
@@ -573,10 +603,9 @@ class LightRAG:
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = {
doc["id"]: doc
for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
# 2. split docs into chunks, insert chunks, update doc status
@@ -585,8 +614,9 @@ class LightRAG:
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}"
):
batch_docs.items(),
desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}",
):
try:
# Generate chunks from document
chunks = {
@@ -616,18 +646,23 @@ class LightRAG:
await self.full_docs.change_status(doc_id, DocStatus.FAILED)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
_failed_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
@@ -635,15 +670,19 @@ class LightRAG:
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously
semaphore = asyncio.Semaphore(
batch_size
) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async with semaphore:
chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])}
chunks = {
i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
}
# Extract and store entities and relationships
try:
maybe_new_kg = await extract_entities(
@@ -662,25 +701,29 @@ class LightRAG:
logger.error("Failed to extract entities and relationships")
# Mark as failed if any step fails
await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
raise e
with tqdm_async(total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0) as progress:
with tqdm_async(
total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0,
) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
for future in asyncio.as_completed(tasks):
await future
progress.update(1)
progress.set_postfix({
'LLM call': statistic_data["llm_call"],
'LLM cache': statistic_data["llm_cache"],
})
progress.set_postfix(
{
"LLM call": statistic_data["llm_call"],
"LLM cache": statistic_data["llm_cache"],
}
)
# Ensure all indexes are updated after each document
await self._insert_done()

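The chunk-extraction loop above caps concurrency with a semaphore sized by insert_batch_size and reports progress as tasks complete. A self-contained sketch of that pattern follows; process_one is a placeholder for the real extract_entities work, not the actual implementation.

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def process_one(chunk_id: str, semaphore: asyncio.Semaphore):
    async with semaphore:
        await asyncio.sleep(0.1)  # placeholder for per-chunk entity extraction
        return chunk_id

async def run(chunk_ids: list[str], batch_size: int = 10):
    semaphore = asyncio.Semaphore(batch_size)  # limit concurrent chunk tasks / LLM calls
    tasks = [asyncio.create_task(process_one(cid, semaphore)) for cid in chunk_ids]
    with tqdm_async(total=len(tasks), desc="Processing chunks", unit="chunk") as progress:
        for future in asyncio.as_completed(tasks):
            await future
            progress.update(1)

# asyncio.run(run([f"chunk-{i}" for i in range(25)]))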
lightrag/operate.py

@@ -20,7 +20,7 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -105,7 +105,9 @@ async def _handle_entity_relation_summary(
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
if len(tokens) < summary_max_tokens: # No need for summary
@@ -360,7 +362,7 @@ async def extract_entities(
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages,ensure_ascii=False)
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
@@ -381,7 +383,7 @@ async def extract_entities(
input_text, history_messages=history_messages
)
else:
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
@@ -394,7 +396,7 @@ async def extract_entities(
return await use_llm_func(input_text)
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
""""Prpocess a single chunk
""" "Prpocess a single chunk
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
@@ -472,7 +474,9 @@ async def extract_entities(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Level 2 - Extracting entities and relationships",
unit="chunk", position=1,leave=False
unit="chunk",
position=1,
leave=False,
):
results.append(await result)
@@ -494,7 +498,9 @@ async def extract_entities(
),
total=len(maybe_nodes),
desc="Level 3 - Inserting entities",
unit="entity", position=2,leave=False
unit="entity",
position=2,
leave=False,
):
all_entities_data.append(await result)
@@ -511,7 +517,9 @@ async def extract_entities(
),
total=len(maybe_edges),
desc="Level 3 - Inserting relationships",
unit="relationship", position=3,leave=False
unit="relationship",
position=3,
leave=False,
):
all_relationships_data.append(await result)

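One behavioural detail from the cache-key handling above: the chat history is serialized with ensure_ascii=False, so non-ASCII prompts (for example the Simplified Chinese configured in the demo) stay readable instead of being escaped. A quick illustration; the message content here is made up.

import json

history = [{"role": "user", "content": "实体抽取示例"}]
print(json.dumps(history, ensure_ascii=False))  # keeps the characters as-is
print(json.dumps(history))                      # escapes them to \uXXXX sequences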
lightrag/utils.py

@@ -41,7 +41,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
@@ -458,7 +458,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return None, None, None, None
# For naive mode, only use simple cache matching
#if mode == "naive":
# if mode == "naive":
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -479,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"]
embedding_model_func = hashing_kv.global_config[
"embedding_func"
].func # ["func"]
llm_model_func = hashing_kv.global_config.get("llm_model_func")
current_embedding = await embedding_model_func([prompt])
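The embedding cache consulted here compares the embedding of the incoming prompt against cached prompt embeddings and reuses an answer above a similarity threshold (0.90 in the demo's commented-out embedding_cache_config). A simplified sketch of that lookup, skipping the quantization the real code performs; the cache layout is an assumption for illustration.

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def lookup(prompt_vec: np.ndarray,
           cache: dict[str, tuple[np.ndarray, str]],
           similarity_threshold: float = 0.90) -> str | None:
    """Return the cached answer whose prompt embedding is most similar, if close enough."""
    best_sim, best_answer = -1.0, None
    for _, (vec, answer) in cache.items():
        sim = cosine_similarity(prompt_vec, vec)
        if sim > best_sim:
            best_sim, best_answer = sim, answer
    return best_answer if best_sim >= similarity_threshold else None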