cleaned code

Author: Yannick Stephan
Date:   2025-02-09 13:18:47 +01:00
Parent: 263a301179
Commit: acbe3e2ff2
2 changed files with 133 additions and 127 deletions


@@ -24,7 +24,6 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
@@ -177,7 +176,9 @@ class LightRAG:
     # extension
     addon_params: dict[str, Any] = field(default_factory=dict)
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = convert_response_to_json
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
+        convert_response_to_json
+    )

     # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
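
As context for the fields touched in this hunk: in a Python dataclass a plain function is a valid class-level default, and wrapping that default in parentheses (as this commit does) only changes line wrapping, not behavior. A minimal, self-contained sketch with stand-in names (PipelineConfig and parse_json are illustrative, not LightRAG's own):

    from dataclasses import dataclass, field
    from typing import Any, Callable
    import json

    def parse_json(text: str) -> dict[str, Any]:
        # Stand-in for convert_response_to_json.
        return json.loads(text)

    @dataclass
    class PipelineConfig:
        # Mutable defaults need default_factory; a function can be assigned directly.
        addon_params: dict[str, Any] = field(default_factory=dict)
        convert_response_to_json_func: Callable[[str], dict[str, Any]] = parse_json
        doc_status_storage: str = "JsonDocStatusStorage"

    cfg = PipelineConfig(addon_params={"insert_batch_size": 20})
    print(cfg.convert_response_to_json_func('{"ok": true}'))  # {'ok': True}
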
@@ -360,7 +361,7 @@ class LightRAG:
             storage.db = db_client

     def insert(
-        self,
+        self,
         string_or_strings: Union[str, list[str]],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
@@ -373,7 +374,7 @@ class LightRAG:
             chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
-        """
+        """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
             self.ainsert(string_or_strings, split_by_character, split_by_character_only)
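
A usage note for the wrapper above: insert() obtains an event loop through always_get_an_event_loop() and runs the async ainsert() to completion, so plain synchronous code can call it. A hypothetical call (the constructor argument shown is illustrative and other required settings are omitted):

    from lightrag import LightRAG

    rag = LightRAG(working_dir="./rag_storage")  # illustrative setup only

    rag.insert(
        ["first document ...", "second document ..."],
        split_by_character="\n\n",      # split on blank lines first
        split_by_character_only=False,  # oversized pieces are re-split by token size
    )
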
@@ -505,7 +506,7 @@ class LightRAG:
             return

         # 4. Store original document
-        await self.doc_status.upsert(new_docs)
+        await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")

     async def apipeline_process_chunks(
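
The enqueue step above amounts to writing one status record per new document and logging the count. A rough, self-contained sketch with an in-memory stand-in for the doc-status storage (JsonDocStatusStorage's real interface is not shown in this diff):

    import asyncio
    from datetime import datetime
    from hashlib import md5
    from typing import Any

    class InMemoryDocStatus:
        # Stand-in for the doc_status storage used above.
        def __init__(self) -> None:
            self.data: dict[str, dict[str, Any]] = {}

        async def upsert(self, records: dict[str, dict[str, Any]]) -> None:
            self.data.update(records)

    async def enqueue(docs: list[str]) -> None:
        doc_status = InMemoryDocStatus()
        new_docs = {
            "doc-" + md5(d.encode()).hexdigest(): {
                "content": d,
                "status": "pending",  # DocStatus.PENDING in the real code
                "created_at": datetime.now().isoformat(),
            }
            for d in docs
        }
        await doc_status.upsert(new_docs)
        print(f"Stored {len(new_docs)} new unique documents")

    asyncio.run(enqueue(["doc one", "doc two"]))
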
@@ -541,23 +542,29 @@ class LightRAG:
if not to_process_doc_keys:
logger.info("All documents have been processed or are duplicates")
return
# If included in text_chunks is all processed, return
new_docs = await self.doc_status.get_by_ids(to_process_doc_keys)
text_chunks_new_docs_ids = await self.text_chunks.filter_keys(to_process_doc_keys)
text_chunks_new_docs_ids = await self.text_chunks.filter_keys(
to_process_doc_keys
)
full_docs_new_docs_ids = await self.full_docs.filter_keys(to_process_doc_keys)
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return
# 2. split docs into chunks, insert chunks, update doc status
batch_size = self.addon_params.get("insert_batch_size", 10)
batch_docs_list = [new_docs[i:i+batch_size] for i in range(0, len(new_docs), batch_size)]
batch_docs_list = [
new_docs[i : i + batch_size] for i in range(0, len(new_docs), batch_size)
]
for i, el in enumerate(batch_docs_list):
items = ((k, v) for d in el for k, v in d.items())
for doc_id, doc in tqdm_async(items, desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}"):
for doc_id, doc in tqdm_async(
items,
desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}",
):
doc_status: dict[str, Any] = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
@@ -567,7 +574,7 @@ class LightRAG:
                 }
                 try:
                     await self.doc_status.upsert({doc_id: doc_status})

                     # Generate chunks from document
                     chunks: dict[str, Any] = {
                         compute_mdhash_id(dp["content"], prefix="chunk-"): {
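
Chunks are keyed by a content hash via compute_mdhash_id, so re-inserting identical text maps to the same key and upserts deduplicate naturally. A sketch of the idea; the MD5-based helper below is an assumption about the utility, not a copy of it:

    from hashlib import md5
    from typing import Any

    def mdhash_id(content: str, prefix: str = "") -> str:
        # Stable, content-derived key: identical text -> identical id (assumed scheme).
        return prefix + md5(content.encode()).hexdigest()

    chunk_text = "LightRAG splits documents into token-sized chunks."
    chunks: dict[str, dict[str, Any]] = {
        mdhash_id(chunk_text, prefix="chunk-"): {
            "content": chunk_text,
            "full_doc_id": mdhash_id("full document text", prefix="doc-"),
        }
    }
    print(list(chunks)[0][:14])  # chunk-xxxxxxxx
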
@@ -584,26 +591,27 @@ class LightRAG:
                         )
                     }
                     await self.chunks_vdb.upsert(chunks)

                     # Update status with chunks information
                     await self._process_entity_relation_graph(chunks)

-                    if not doc_id in full_docs_new_docs_ids:
+                    if doc_id not in full_docs_new_docs_ids:
                         await self.full_docs.upsert(
-                            {doc_id: {"content": doc["content"]}}
-                        )
-                    if not doc_id in text_chunks_new_docs_ids:
+                            {doc_id: {"content": doc["content"]}}
+                        )
+                    if doc_id not in text_chunks_new_docs_ids:
                         await self.text_chunks.upsert(chunks)

                     doc_status.update(
                         {
                             "status": DocStatus.PROCESSED,
                             "chunks_count": len(chunks),
                             "updated_at": datetime.now().isoformat(),
                         }
-                    )
+                    )
                     await self.doc_status.upsert({doc_id: doc_status})
                     await self._insert_done()
                 except Exception as e:
                     # Update status with failed information
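
The bookkeeping above follows a small state machine: a document enters as pending, is stamped processed with a chunk count and timestamp on success, and the except branch records a failure instead. A compact sketch of those transitions (plain strings stand in for the DocStatus enum, and the error field is illustrative):

    from datetime import datetime
    from typing import Any

    def mark_processed(doc_status: dict[str, Any], chunks: dict[str, Any]) -> None:
        doc_status.update(
            {
                "status": "processed",  # DocStatus.PROCESSED in the real code
                "chunks_count": len(chunks),
                "updated_at": datetime.now().isoformat(),
            }
        )

    def mark_failed(doc_status: dict[str, Any], error: Exception) -> None:
        doc_status.update(
            {
                "status": "failed",  # DocStatus.FAILED in the real code
                "error": str(error),  # illustrative field
                "updated_at": datetime.now().isoformat(),
            }
        )

    record: dict[str, Any] = {"content_length": 1234, "status": "pending"}
    mark_processed(record, {"chunk-a": {}, "chunk-b": {}})
    print(record["status"], record["chunks_count"])  # processed 2
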
@@ -620,122 +628,120 @@ class LightRAG:
                     )
                     continue

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
             new_kg = await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
                 llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
             )
             if new_kg is None:
                 logger.info("No entities or relationships extracted!")
             else:
                 self.chunk_entity_relation_graph = new_kg
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e

-    async def apipeline_process_extract_graph(self):
-        """
-        Process pending or failed chunks to extract entities and relationships.
-
-        This method retrieves all chunks that are currently marked as pending or have previously failed.
-        It then extracts entities and relationships from each chunk and updates the status accordingly.
-
-        Steps:
-        1. Retrieve all pending and failed chunks.
-        2. For each chunk, attempt to extract entities and relationships.
-        3. Update the chunk's status to processed if successful, or failed if an error occurs.
-
-        Raises:
-            Exception: If there is an error during the extraction process.
-
-        Returns:
-            None
-        """
-        # 1. get all pending and failed chunks
-        to_process_doc_keys: list[str] = []
-
-        # Process failes
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
-        # Process Pending
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
-        if not to_process_doc_keys:
-            logger.info("All documents have been processed or are duplicates")
-            return
-
-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
-
-        semaphore = asyncio.Semaphore(
-            batch_size
-        ) # Control the number of tasks that are processed simultaneously
-
-        async def process_chunk(chunk_id: str):
-            async with semaphore:
-                chunks: dict[str, Any] = {
-                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-                }
-
-        async def _process_chunk(chunk_id: str):
-            chunks: dict[str, Any] = {
-                i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-            }
-
-            # Extract and store entities and relationships
-            try:
-                maybe_new_kg = await extract_entities(
-                    chunks,
-                    knowledge_graph_inst=self.chunk_entity_relation_graph,
-                    entity_vdb=self.entities_vdb,
-                    relationships_vdb=self.relationships_vdb,
-                    llm_response_cache=self.llm_response_cache,
-                    global_config=asdict(self),
-                )
-                if maybe_new_kg is None:
-                    logger.warning("No entities or relationships extracted!")
-                # Update status to processed
-                await self.text_chunks.upsert(chunks)
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
-            except Exception as e:
-                logger.error("Failed to extract entities and relationships")
-                # Mark as failed if any step fails
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
-                raise e
-
-        with tqdm_async(
-            total=len(to_process_doc_keys),
-            desc="\nLevel 1 - Processing chunks",
-            unit="chunk",
-            position=0,
-        ) as progress:
-            tasks: list[asyncio.Task[None]] = []
-            for chunk_id in to_process_doc_keys:
-                task = asyncio.create_task(process_chunk(chunk_id))
-                tasks.append(task)
-
-            for future in asyncio.as_completed(tasks):
-                await future
-                progress.update(1)
-                progress.set_postfix(
-                    {
-                        "LLM call": statistic_data["llm_call"],
-                        "LLM cache": statistic_data["llm_cache"],
-                    }
-                )
-
-        # Ensure all indexes are updated after each document
-        await self._insert_done()
+    # async def apipeline_process_extract_graph(self):
+    #     """
+    #     Process pending or failed chunks to extract entities and relationships.
+    #     This method retrieves all chunks that are currently marked as pending or have previously failed.
+    #     It then extracts entities and relationships from each chunk and updates the status accordingly.
+    #     Steps:
+    #     1. Retrieve all pending and failed chunks.
+    #     2. For each chunk, attempt to extract entities and relationships.
+    #     3. Update the chunk's status to processed if successful, or failed if an error occurs.
+    #     Raises:
+    #         Exception: If there is an error during the extraction process.
+    #     Returns:
+    #         None
+    #     """
+    #     # 1. get all pending and failed chunks
+    #     to_process_doc_keys: list[str] = []
+    #     # Process failes
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     # Process Pending
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     if not to_process_doc_keys:
+    #         logger.info("All documents have been processed or are duplicates")
+    #         return
+    #     # Process documents in batches
+    #     batch_size = self.addon_params.get("insert_batch_size", 10)
+    #     semaphore = asyncio.Semaphore(
+    #         batch_size
+    #     ) # Control the number of tasks that are processed simultaneously
+    #     async def process_chunk(chunk_id: str):
+    #         async with semaphore:
+    #             chunks: dict[str, Any] = {
+    #                 i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #             }
+    #     async def _process_chunk(chunk_id: str):
+    #         chunks: dict[str, Any] = {
+    #             i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #         }
+    #         # Extract and store entities and relationships
+    #         try:
+    #             maybe_new_kg = await extract_entities(
+    #                 chunks,
+    #                 knowledge_graph_inst=self.chunk_entity_relation_graph,
+    #                 entity_vdb=self.entities_vdb,
+    #                 relationships_vdb=self.relationships_vdb,
+    #                 llm_response_cache=self.llm_response_cache,
+    #                 global_config=asdict(self),
+    #             )
+    #             if maybe_new_kg is None:
+    #                 logger.warning("No entities or relationships extracted!")
+    #             # Update status to processed
+    #             await self.text_chunks.upsert(chunks)
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
+    #         except Exception as e:
+    #             logger.error("Failed to extract entities and relationships")
+    #             # Mark as failed if any step fails
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
+    #             raise e
+    #     with tqdm_async(
+    #         total=len(to_process_doc_keys),
+    #         desc="\nLevel 1 - Processing chunks",
+    #         unit="chunk",
+    #         position=0,
+    #     ) as progress:
+    #         tasks: list[asyncio.Task[None]] = []
+    #         for chunk_id in to_process_doc_keys:
+    #             task = asyncio.create_task(process_chunk(chunk_id))
+    #             tasks.append(task)
+    #         for future in asyncio.as_completed(tasks):
+    #             await future
+    #             progress.update(1)
+    #             progress.set_postfix(
+    #                 {
+    #                     "LLM call": statistic_data["llm_call"],
+    #                     "LLM cache": statistic_data["llm_cache"],
+    #                 }
+    #             )
+    #     # Ensure all indexes are updated after each document

     async def _insert_done(self):
         tasks = []
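
The commented-out method above used a common asyncio pattern: a Semaphore bounds how many chunks are processed concurrently, one task is created per chunk, and asyncio.as_completed drives progress reporting as tasks finish. A self-contained sketch of that pattern, with a sleep standing in for the LLM-backed extract_entities call and the status upserts:

    import asyncio

    async def run(chunk_ids: list[str], batch_size: int = 10) -> None:
        semaphore = asyncio.Semaphore(batch_size)  # cap concurrent workers

        async def process_chunk(chunk_id: str) -> str:
            async with semaphore:
                await asyncio.sleep(0.01)  # stand-in for extraction + upserts
                return chunk_id

        tasks = [asyncio.create_task(process_chunk(cid)) for cid in chunk_ids]
        done = 0
        for future in asyncio.as_completed(tasks):
            await future
            done += 1
            print(f"processed {done}/{len(tasks)} chunks")

    asyncio.run(run([f"chunk-{n}" for n in range(25)]))
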