support pipeline mode

This commit is contained in:
jin
2025-01-16 12:58:15 +08:00
parent d5ae6669ea
commit 6ae8647285
6 changed files with 203 additions and 172 deletions

View File

@@ -20,7 +20,7 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -105,7 +105,9 @@ async def _handle_entity_relation_summary(
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
if len(tokens) < summary_max_tokens: # No need for summary
@@ -360,7 +362,7 @@ async def extract_entities(
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages,ensure_ascii=False)
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
@@ -381,7 +383,7 @@ async def extract_entities(
input_text, history_messages=history_messages
)
else:
res: str = await use_llm_func(input_text)
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
@@ -394,7 +396,7 @@ async def extract_entities(
return await use_llm_func(input_text)
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
""""Prpocess a single chunk
""" "Prpocess a single chunk
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
@@ -472,7 +474,9 @@ async def extract_entities(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Level 2 - Extracting entities and relationships",
unit="chunk", position=1,leave=False
unit="chunk",
position=1,
leave=False,
):
results.append(await result)
@@ -494,7 +498,9 @@ async def extract_entities(
),
total=len(maybe_nodes),
desc="Level 3 - Inserting entities",
unit="entity", position=2,leave=False
unit="entity",
position=2,
leave=False,
):
all_entities_data.append(await result)
@@ -511,7 +517,9 @@ async def extract_entities(
),
total=len(maybe_edges),
desc="Level 3 - Inserting relationships",
unit="relationship", position=3,leave=False
unit="relationship",
position=3,
leave=False,
):
all_relationships_data.append(await result)