cleaned code

Yannick Stephan
2025-02-09 13:18:47 +01:00
parent 263a301179
commit acbe3e2ff2
2 changed files with 133 additions and 127 deletions

View File

@@ -24,7 +24,6 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
@@ -177,7 +176,9 @@ class LightRAG:
     # extension
     addon_params: dict[str, Any] = field(default_factory=dict)
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = convert_response_to_json
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
+        convert_response_to_json
+    )

     # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
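
Note: the parenthesized default above is purely stylistic, but it highlights that convert_response_to_json_func is an ordinary dataclass field holding a callable, so a caller can swap in their own parser. A minimal sketch under that assumption; strict_json_parser and the constructor arguments shown are hypothetical illustrations, not part of this commit:

import json
from typing import Any

def strict_json_parser(response: str) -> dict[str, Any]:
    # Hypothetical replacement parser: raise on malformed output instead of salvaging it.
    return json.loads(response)

# Assuming LightRAG is a dataclass, a keyword argument replaces the default callable:
# rag = LightRAG(working_dir="./rag_storage", convert_response_to_json_func=strict_json_parser)
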
@@ -544,8 +545,9 @@ class LightRAG:
         # If included in text_chunks is all processed, return
         new_docs = await self.doc_status.get_by_ids(to_process_doc_keys)
-        text_chunks_new_docs_ids = await self.text_chunks.filter_keys(to_process_doc_keys)
+        text_chunks_new_docs_ids = await self.text_chunks.filter_keys(
+            to_process_doc_keys
+        )
         full_docs_new_docs_ids = await self.full_docs.filter_keys(to_process_doc_keys)

         if not new_docs:
@@ -554,10 +556,15 @@ class LightRAG:
         # 2. split docs into chunks, insert chunks, update doc status
         batch_size = self.addon_params.get("insert_batch_size", 10)
-        batch_docs_list = [new_docs[i:i+batch_size] for i in range(0, len(new_docs), batch_size)]
+        batch_docs_list = [
+            new_docs[i : i + batch_size] for i in range(0, len(new_docs), batch_size)
+        ]

         for i, el in enumerate(batch_docs_list):
             items = ((k, v) for d in el for k, v in d.items())
-            for doc_id, doc in tqdm_async(items, desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}"):
+            for doc_id, doc in tqdm_async(
+                items,
+                desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}",
+            ):
                 doc_status: dict[str, Any] = {
                     "content_summary": doc["content_summary"],
                     "content_length": doc["content_length"],
@@ -588,12 +595,12 @@ class LightRAG:
                     # Update status with chunks information
                     await self._process_entity_relation_graph(chunks)

-                    if not doc_id in full_docs_new_docs_ids:
+                    if doc_id not in full_docs_new_docs_ids:
                         await self.full_docs.upsert(
                             {doc_id: {"content": doc["content"]}}
                         )

-                    if not doc_id in text_chunks_new_docs_ids:
+                    if doc_id not in text_chunks_new_docs_ids:
                         await self.text_chunks.upsert(chunks)

                     doc_status.update(
@@ -604,6 +611,7 @@ class LightRAG:
                         }
                     )
                     await self.doc_status.upsert({doc_id: doc_status})
+                    await self._insert_done()
                 except Exception as e:
                     # Update status with failed information
@@ -620,122 +628,120 @@ class LightRAG:
                     )
                     continue

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
             new_kg = await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
                 llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
             )
             if new_kg is None:
                 logger.info("No entities or relationships extracted!")
             else:
                 self.chunk_entity_relation_graph = new_kg
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e

-    async def apipeline_process_extract_graph(self):
-        """
-        Process pending or failed chunks to extract entities and relationships.
-        This method retrieves all chunks that are currently marked as pending or have previously failed.
-        It then extracts entities and relationships from each chunk and updates the status accordingly.
-        Steps:
-            1. Retrieve all pending and failed chunks.
-            2. For each chunk, attempt to extract entities and relationships.
-            3. Update the chunk's status to processed if successful, or failed if an error occurs.
-        Raises:
-            Exception: If there is an error during the extraction process.
-        Returns:
-            None
-        """
-        # 1. get all pending and failed chunks
-        to_process_doc_keys: list[str] = []
-        # Process failes
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-        # Process Pending
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-        if not to_process_doc_keys:
-            logger.info("All documents have been processed or are duplicates")
-            return
-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
-        semaphore = asyncio.Semaphore(
-            batch_size
-        )  # Control the number of tasks that are processed simultaneously
-        async def process_chunk(chunk_id: str):
-            async with semaphore:
-                chunks: dict[str, Any] = {
-                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-                }
-        async def _process_chunk(chunk_id: str):
-            chunks: dict[str, Any] = {
-                i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-            }
-            # Extract and store entities and relationships
-            try:
-                maybe_new_kg = await extract_entities(
-                    chunks,
-                    knowledge_graph_inst=self.chunk_entity_relation_graph,
-                    entity_vdb=self.entities_vdb,
-                    relationships_vdb=self.relationships_vdb,
-                    llm_response_cache=self.llm_response_cache,
-                    global_config=asdict(self),
-                )
-                if maybe_new_kg is None:
-                    logger.warning("No entities or relationships extracted!")
-                # Update status to processed
-                await self.text_chunks.upsert(chunks)
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
-            except Exception as e:
-                logger.error("Failed to extract entities and relationships")
-                # Mark as failed if any step fails
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
-                raise e
-        with tqdm_async(
-            total=len(to_process_doc_keys),
-            desc="\nLevel 1 - Processing chunks",
-            unit="chunk",
-            position=0,
-        ) as progress:
-            tasks: list[asyncio.Task[None]] = []
-            for chunk_id in to_process_doc_keys:
-                task = asyncio.create_task(process_chunk(chunk_id))
-                tasks.append(task)
-            for future in asyncio.as_completed(tasks):
-                await future
-                progress.update(1)
-                progress.set_postfix(
-                    {
-                        "LLM call": statistic_data["llm_call"],
-                        "LLM cache": statistic_data["llm_cache"],
-                    }
-                )
-        # Ensure all indexes are updated after each document
-        await self._insert_done()
+    # async def apipeline_process_extract_graph(self):
+    #     """
+    #     Process pending or failed chunks to extract entities and relationships.
+    #     This method retrieves all chunks that are currently marked as pending or have previously failed.
+    #     It then extracts entities and relationships from each chunk and updates the status accordingly.
+    #     Steps:
+    #         1. Retrieve all pending and failed chunks.
+    #         2. For each chunk, attempt to extract entities and relationships.
+    #         3. Update the chunk's status to processed if successful, or failed if an error occurs.
+    #     Raises:
+    #         Exception: If there is an error during the extraction process.
+    #     Returns:
+    #         None
+    #     """
+    #     # 1. get all pending and failed chunks
+    #     to_process_doc_keys: list[str] = []
+    #     # Process failes
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     # Process Pending
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     if not to_process_doc_keys:
+    #         logger.info("All documents have been processed or are duplicates")
+    #         return
+    #     # Process documents in batches
+    #     batch_size = self.addon_params.get("insert_batch_size", 10)
+    #     semaphore = asyncio.Semaphore(
+    #         batch_size
+    #     )  # Control the number of tasks that are processed simultaneously
+    #     async def process_chunk(chunk_id: str):
+    #         async with semaphore:
+    #             chunks: dict[str, Any] = {
+    #                 i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #             }
+    #     async def _process_chunk(chunk_id: str):
+    #         chunks: dict[str, Any] = {
+    #             i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #         }
+    #         # Extract and store entities and relationships
+    #         try:
+    #             maybe_new_kg = await extract_entities(
+    #                 chunks,
+    #                 knowledge_graph_inst=self.chunk_entity_relation_graph,
+    #                 entity_vdb=self.entities_vdb,
+    #                 relationships_vdb=self.relationships_vdb,
+    #                 llm_response_cache=self.llm_response_cache,
+    #                 global_config=asdict(self),
+    #             )
+    #             if maybe_new_kg is None:
+    #                 logger.warning("No entities or relationships extracted!")
+    #             # Update status to processed
+    #             await self.text_chunks.upsert(chunks)
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
+    #         except Exception as e:
+    #             logger.error("Failed to extract entities and relationships")
+    #             # Mark as failed if any step fails
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
+    #             raise e
+    #     with tqdm_async(
+    #         total=len(to_process_doc_keys),
+    #         desc="\nLevel 1 - Processing chunks",
+    #         unit="chunk",
+    #         position=0,
+    #     ) as progress:
+    #         tasks: list[asyncio.Task[None]] = []
+    #         for chunk_id in to_process_doc_keys:
+    #             task = asyncio.create_task(process_chunk(chunk_id))
+    #             tasks.append(task)
+    #         for future in asyncio.as_completed(tasks):
+    #             await future
+    #             progress.update(1)
+    #             progress.set_postfix(
+    #                 {
+    #                     "LLM call": statistic_data["llm_call"],
+    #                     "LLM cache": statistic_data["llm_cache"],
+    #                 }
+    #             )
+    #     # Ensure all indexes are updated after each document

     async def _insert_done(self):
         tasks = []
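
Note: the block commented out above throttled per-chunk extraction with an asyncio.Semaphore sized from insert_batch_size and drove a progress bar via asyncio.as_completed. A minimal, self-contained sketch of that bounded-concurrency pattern; process_one and its sleep are placeholders for the real extraction call, not LightRAG APIs:

import asyncio

async def run_bounded(chunk_ids: list[str], max_concurrency: int = 10) -> None:
    # The semaphore caps how many chunk tasks run at once, like the
    # commented-out process_chunk wrapper did.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def process_one(chunk_id: str) -> None:
        async with semaphore:
            await asyncio.sleep(0.01)  # placeholder for entity/relationship extraction

    tasks = [asyncio.create_task(process_one(c)) for c in chunk_ids]
    for future in asyncio.as_completed(tasks):
        await future  # yields in completion order, convenient for a progress bar

asyncio.run(run_bounded([f"chunk-{n}" for n in range(25)]))
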

View File

@@ -36,11 +36,11 @@ import time
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None]=None,
-    split_by_character_only: bool =False,
-    overlap_token_size: int =128,
-    max_token_size: int =1024,
-    tiktoken_model: str="gpt-4o"
+    split_by_character: Union[str, None] = None,
+    split_by_character_only: bool = False,
+    overlap_token_size: int = 128,
+    max_token_size: int = 1024,
+    tiktoken_model: str = "gpt-4o",
 ) -> list[dict[str, Any]]:
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results: list[dict[str, Any]] = []
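
Note: with the normalized defaults above, everything after content can be passed by keyword. A usage sketch; the import path is an assumption based on this repository's layout, and the final print only shows that a list of chunk dicts comes back (the exact keys are not visible in this hunk):

from lightrag.operate import chunking_by_token_size  # assumed module path

chunks = chunking_by_token_size(
    "A long document ... " * 500,
    split_by_character=None,        # purely token-based splitting
    split_by_character_only=False,
    overlap_token_size=128,         # tokens shared between neighbouring chunks
    max_token_size=1024,            # upper bound per chunk
    tiktoken_model="gpt-4o",
)
print(len(chunks))  # number of chunk dicts produced
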