cleaned code
@@ -24,7 +24,6 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
@@ -177,7 +176,9 @@ class LightRAG:

     # extension
     addon_params: dict[str, Any] = field(default_factory=dict)
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = convert_response_to_json
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
+        convert_response_to_json
+    )

     # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
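Note: since `convert_response_to_json_func` is a plain dataclass field, a caller can swap in a custom parser. A minimal sketch, assuming `LightRAG` is importable from the package root and that a `working_dir` field exists (neither is shown in this diff):

    import json
    from typing import Any

    from lightrag import LightRAG  # import path assumed

    def strict_json(response: str) -> dict[str, Any]:
        # Hypothetical replacement parser: fail loudly on malformed JSON
        # instead of using the default convert_response_to_json helper.
        return json.loads(response)

    rag = LightRAG(
        working_dir="./rag_storage",  # assumed field, not shown in this diff
        convert_response_to_json_func=strict_json,
    )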
@@ -544,8 +545,9 @@ class LightRAG:

         # If included in text_chunks is all processed, return
         new_docs = await self.doc_status.get_by_ids(to_process_doc_keys)
-        text_chunks_new_docs_ids = await self.text_chunks.filter_keys(to_process_doc_keys)
+        text_chunks_new_docs_ids = await self.text_chunks.filter_keys(
+            to_process_doc_keys
+        )
         full_docs_new_docs_ids = await self.full_docs.filter_keys(to_process_doc_keys)

         if not new_docs:
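Note: the deduplication above leans on `filter_keys`. A common contract for such KV-storage helpers, and a plausible reading of this call site, is to return the subset of candidate keys not yet stored; a minimal sketch under that assumption (the real implementations live in the storage classes, outside this diff):

    from typing import Any

    class InMemoryKV:
        # Hypothetical stand-in for the JSON/DB-backed storages used above.
        def __init__(self) -> None:
            self._data: dict[str, Any] = {}

        async def filter_keys(self, keys: list[str]) -> set[str]:
            # Return the keys that are NOT already stored.
            return {k for k in keys if k not in self._data}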
@@ -554,10 +556,15 @@ class LightRAG:

         # 2. split docs into chunks, insert chunks, update doc status
         batch_size = self.addon_params.get("insert_batch_size", 10)
-        batch_docs_list = [new_docs[i:i+batch_size] for i in range(0, len(new_docs), batch_size)]
+        batch_docs_list = [
+            new_docs[i : i + batch_size] for i in range(0, len(new_docs), batch_size)
+        ]
         for i, el in enumerate(batch_docs_list):
             items = ((k, v) for d in el for k, v in d.items())
-            for doc_id, doc in tqdm_async(items, desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}"):
+            for doc_id, doc in tqdm_async(
+                items,
+                desc=f"Level 1 - Spliting doc in batch {i // len(batch_docs_list) + 1}",
+            ):
                 doc_status: dict[str, Any] = {
                     "content_summary": doc["content_summary"],
                     "content_length": doc["content_length"],
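Note: the reformatted comprehension is the standard fixed-size batching idiom; slicing past the end of a list is safe in Python, so the last batch simply holds the remainder. A standalone sketch with illustrative names (also worth noting: `i` already indexes batches here, so `i // len(batch_docs_list)` is 0 for every batch and the progress label always reads "batch 1"; `i + 1` was likely intended):

    def make_batches(items: list, batch_size: int = 10) -> list[list]:
        # range(0, len(items), batch_size) yields each batch's start index.
        return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

    assert make_batches(list(range(7)), 3) == [[0, 1, 2], [3, 4, 5], [6]]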
@@ -588,12 +595,12 @@ class LightRAG:
                     # Update status with chunks information
                     await self._process_entity_relation_graph(chunks)

-                    if not doc_id in full_docs_new_docs_ids:
+                    if doc_id not in full_docs_new_docs_ids:
                         await self.full_docs.upsert(
                             {doc_id: {"content": doc["content"]}}
                         )

-                    if not doc_id in text_chunks_new_docs_ids:
+                    if doc_id not in text_chunks_new_docs_ids:
                         await self.text_chunks.upsert(chunks)

                     doc_status.update(
@@ -604,6 +611,7 @@ class LightRAG:
                         }
                     )
                     await self.doc_status.upsert({doc_id: doc_status})
+                    await self._insert_done()

                 except Exception as e:
                     # Update status with failed information
@@ -620,7 +628,6 @@ class LightRAG:
                     )
                     continue

-
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
             new_kg = await extract_entities(
@@ -640,102 +647,101 @@ class LightRAG:
             logger.error("Failed to extract entities and relationships")
             raise e

-    async def apipeline_process_extract_graph(self):
-        """
-        Process pending or failed chunks to extract entities and relationships.
+    # async def apipeline_process_extract_graph(self):
+    #     """
+    #     Process pending or failed chunks to extract entities and relationships.

-        This method retrieves all chunks that are currently marked as pending or have previously failed.
-        It then extracts entities and relationships from each chunk and updates the status accordingly.
+    #     This method retrieves all chunks that are currently marked as pending or have previously failed.
+    #     It then extracts entities and relationships from each chunk and updates the status accordingly.

-        Steps:
-        1. Retrieve all pending and failed chunks.
-        2. For each chunk, attempt to extract entities and relationships.
-        3. Update the chunk's status to processed if successful, or failed if an error occurs.
+    #     Steps:
+    #     1. Retrieve all pending and failed chunks.
+    #     2. For each chunk, attempt to extract entities and relationships.
+    #     3. Update the chunk's status to processed if successful, or failed if an error occurs.

-        Raises:
-            Exception: If there is an error during the extraction process.
+    #     Raises:
+    #         Exception: If there is an error during the extraction process.

-        Returns:
-            None
-        """
-        # 1. get all pending and failed chunks
-        to_process_doc_keys: list[str] = []
+    #     Returns:
+    #         None
+    #     """
+    #     # 1. get all pending and failed chunks
+    #     to_process_doc_keys: list[str] = []

-        # Process failes
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     # Process failes
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])

-        # Process Pending
-        to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
-        if to_process_docs:
-            to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
+    #     # Process Pending
+    #     to_process_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
+    #     if to_process_docs:
+    #         to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])

-        if not to_process_doc_keys:
-            logger.info("All documents have been processed or are duplicates")
-            return
+    #     if not to_process_doc_keys:
+    #         logger.info("All documents have been processed or are duplicates")
+    #         return

-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
+    #     # Process documents in batches
+    #     batch_size = self.addon_params.get("insert_batch_size", 10)

-        semaphore = asyncio.Semaphore(
-            batch_size
-        ) # Control the number of tasks that are processed simultaneously
+    #     semaphore = asyncio.Semaphore(
+    #         batch_size
+    #     ) # Control the number of tasks that are processed simultaneously

-        async def process_chunk(chunk_id: str):
-            async with semaphore:
-                chunks: dict[str, Any] = {
-                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-                }
-        async def _process_chunk(chunk_id: str):
-            chunks: dict[str, Any] = {
-                i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-            }
+    #     async def process_chunk(chunk_id: str):
+    #         async with semaphore:
+    #             chunks: dict[str, Any] = {
+    #                 i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #             }
+    #     async def _process_chunk(chunk_id: str):
+    #         chunks: dict[str, Any] = {
+    #             i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+    #         }

-            # Extract and store entities and relationships
-            try:
-                maybe_new_kg = await extract_entities(
-                    chunks,
-                    knowledge_graph_inst=self.chunk_entity_relation_graph,
-                    entity_vdb=self.entities_vdb,
-                    relationships_vdb=self.relationships_vdb,
-                    llm_response_cache=self.llm_response_cache,
-                    global_config=asdict(self),
-                )
-                if maybe_new_kg is None:
-                    logger.warning("No entities or relationships extracted!")
-                # Update status to processed
-                await self.text_chunks.upsert(chunks)
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
-            except Exception as e:
-                logger.error("Failed to extract entities and relationships")
-                # Mark as failed if any step fails
-                await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
-                raise e
+    #         # Extract and store entities and relationships
+    #         try:
+    #             maybe_new_kg = await extract_entities(
+    #                 chunks,
+    #                 knowledge_graph_inst=self.chunk_entity_relation_graph,
+    #                 entity_vdb=self.entities_vdb,
+    #                 relationships_vdb=self.relationships_vdb,
+    #                 llm_response_cache=self.llm_response_cache,
+    #                 global_config=asdict(self),
+    #             )
+    #             if maybe_new_kg is None:
+    #                 logger.warning("No entities or relationships extracted!")
+    #             # Update status to processed
+    #             await self.text_chunks.upsert(chunks)
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
+    #         except Exception as e:
+    #             logger.error("Failed to extract entities and relationships")
+    #             # Mark as failed if any step fails
+    #             await self.doc_status.upsert({chunk_id: {"status": DocStatus.FAILED}})
+    #             raise e

-        with tqdm_async(
-            total=len(to_process_doc_keys),
-            desc="\nLevel 1 - Processing chunks",
-            unit="chunk",
-            position=0,
-        ) as progress:
-            tasks: list[asyncio.Task[None]] = []
-            for chunk_id in to_process_doc_keys:
-                task = asyncio.create_task(process_chunk(chunk_id))
-                tasks.append(task)
+    #     with tqdm_async(
+    #         total=len(to_process_doc_keys),
+    #         desc="\nLevel 1 - Processing chunks",
+    #         unit="chunk",
+    #         position=0,
+    #     ) as progress:
+    #         tasks: list[asyncio.Task[None]] = []
+    #         for chunk_id in to_process_doc_keys:
+    #             task = asyncio.create_task(process_chunk(chunk_id))
+    #             tasks.append(task)

-            for future in asyncio.as_completed(tasks):
-                await future
-                progress.update(1)
-                progress.set_postfix(
-                    {
-                        "LLM call": statistic_data["llm_call"],
-                        "LLM cache": statistic_data["llm_cache"],
-                    }
-                )
+    #         for future in asyncio.as_completed(tasks):
+    #             await future
+    #             progress.update(1)
+    #             progress.set_postfix(
+    #                 {
+    #                     "LLM call": statistic_data["llm_call"],
+    #                     "LLM cache": statistic_data["llm_cache"],
+    #                 }
+    #             )

-        # Ensure all indexes are updated after each document
-        await self._insert_done()
+    #     # Ensure all indexes are updated after each document

     async def _insert_done(self):
         tasks = []
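Note: the commented-out method above caps concurrency with `asyncio.Semaphore` and drains tasks via `asyncio.as_completed`. A minimal, self-contained sketch of that pattern, with a stand-in coroutine instead of the real entity extraction:

    import asyncio

    async def process(item: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for the real extraction call
        return item.upper()

    async def run_bounded(items: list[str], limit: int = 10) -> list[str]:
        # Mirrors asyncio.Semaphore(batch_size): at most `limit` tasks
        # run concurrently; the rest wait at the semaphore.
        semaphore = asyncio.Semaphore(limit)

        async def bounded(item: str) -> str:
            async with semaphore:
                return await process(item)

        tasks = [asyncio.create_task(bounded(i)) for i in items]
        results = []
        for future in asyncio.as_completed(tasks):
            results.append(await future)  # completion order, not input order
        return results

    print(asyncio.run(run_bounded(["a", "b", "c"], limit=2)))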
@@ -36,11 +36,11 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None]=None,
-    split_by_character_only: bool =False,
-    overlap_token_size: int =128,
-    max_token_size: int =1024,
-    tiktoken_model: str="gpt-4o"
+    split_by_character: Union[str, None] = None,
+    split_by_character_only: bool = False,
+    overlap_token_size: int = 128,
+    max_token_size: int = 1024,
+    tiktoken_model: str = "gpt-4o",
 ) -> list[dict[str, Any]]:
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results: list[dict[str, Any]] = []
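Note: with the defaults normalized to PEP 8 spacing, a call sketch for the signature above (argument values are illustrative, and the exact keys of the returned dicts are not shown in this hunk):

    chunks = chunking_by_token_size(
        content=document_text,      # any input string
        split_by_character="\n\n",  # optional pre-split boundary
        split_by_character_only=False,
        overlap_token_size=128,
        max_token_size=1024,
        tiktoken_model="gpt-4o",
    )
    # Per the annotation, `chunks` is a list[dict[str, Any]].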