Improve Entity Extraction Robustness for Truncated LLM Responses
This commit is contained in:
@@ -141,18 +141,36 @@ async def _handle_single_entity_extraction(
|
|||||||
):
|
):
|
||||||
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
||||||
return None
|
return None
|
||||||
# add this record as a node in the G
|
|
||||||
|
# Clean and validate entity name
|
||||||
entity_name = clean_str(record_attributes[1]).strip('"')
|
entity_name = clean_str(record_attributes[1]).strip('"')
|
||||||
if not entity_name.strip():
|
if not entity_name.strip():
|
||||||
|
logger.warning(
|
||||||
|
f"Entity extraction error: empty entity name in: {record_attributes}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Clean and validate entity type
|
||||||
entity_type = clean_str(record_attributes[2]).strip('"')
|
entity_type = clean_str(record_attributes[2]).strip('"')
|
||||||
|
if not entity_type.strip() or entity_type.startswith('("'):
|
||||||
|
logger.warning(
|
||||||
|
f"Entity extraction error: invalid entity type in: {record_attributes}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clean and validate description
|
||||||
entity_description = clean_str(record_attributes[3]).strip('"')
|
entity_description = clean_str(record_attributes[3]).strip('"')
|
||||||
entity_source_id = chunk_key
|
if not entity_description.strip():
|
||||||
|
logger.warning(
|
||||||
|
f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
entity_name=entity_name,
|
entity_name=entity_name,
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
description=entity_description,
|
description=entity_description,
|
||||||
source_id=entity_source_id,
|
source_id=chunk_key,
|
||||||
metadata={"created_at": time.time()},
|
metadata={"created_at": time.time()},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -438,47 +456,22 @@ async def extract_entities(
|
|||||||
else:
|
else:
|
||||||
return await use_llm_func(input_text)
|
return await use_llm_func(input_text)
|
||||||
|
|
||||||
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
async def _process_extraction_result(result: str, chunk_key: str):
|
||||||
""" "Prpocess a single chunk
|
"""Process a single extraction result (either initial or gleaning)
|
||||||
Args:
|
Args:
|
||||||
chunk_key_dp (tuple[str, TextChunkSchema]):
|
result (str): The extraction result to process
|
||||||
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
chunk_key (str): The chunk key for source tracking
|
||||||
|
Returns:
|
||||||
|
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
|
||||||
"""
|
"""
|
||||||
nonlocal processed_chunks
|
|
||||||
chunk_key = chunk_key_dp[0]
|
|
||||||
chunk_dp = chunk_key_dp[1]
|
|
||||||
content = chunk_dp["content"]
|
|
||||||
# hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
|
|
||||||
hint_prompt = entity_extract_prompt.format(
|
|
||||||
**context_base, input_text="{input_text}"
|
|
||||||
).format(**context_base, input_text=content)
|
|
||||||
|
|
||||||
final_result = await _user_llm_func_with_cache(hint_prompt)
|
|
||||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
|
||||||
for now_glean_index in range(entity_extract_max_gleaning):
|
|
||||||
glean_result = await _user_llm_func_with_cache(
|
|
||||||
continue_prompt, history_messages=history
|
|
||||||
)
|
|
||||||
|
|
||||||
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
|
||||||
final_result += glean_result
|
|
||||||
if now_glean_index == entity_extract_max_gleaning - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
if_loop_result: str = await _user_llm_func_with_cache(
|
|
||||||
if_loop_prompt, history_messages=history
|
|
||||||
)
|
|
||||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
|
||||||
if if_loop_result != "yes":
|
|
||||||
break
|
|
||||||
|
|
||||||
records = split_string_by_multi_markers(
|
|
||||||
final_result,
|
|
||||||
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
maybe_nodes = defaultdict(list)
|
maybe_nodes = defaultdict(list)
|
||||||
maybe_edges = defaultdict(list)
|
maybe_edges = defaultdict(list)
|
||||||
|
|
||||||
|
records = split_string_by_multi_markers(
|
||||||
|
result,
|
||||||
|
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
||||||
|
)
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
record = re.search(r"\((.*)\)", record)
|
record = re.search(r"\((.*)\)", record)
|
||||||
if record is None:
|
if record is None:
|
||||||
@@ -487,13 +480,14 @@ async def extract_entities(
|
|||||||
record_attributes = split_string_by_multi_markers(
|
record_attributes = split_string_by_multi_markers(
|
||||||
record, [context_base["tuple_delimiter"]]
|
record, [context_base["tuple_delimiter"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
if_entities = await _handle_single_entity_extraction(
|
if_entities = await _handle_single_entity_extraction(
|
||||||
record_attributes, chunk_key
|
record_attributes, chunk_key
|
||||||
)
|
)
|
||||||
if if_entities is not None:
|
if if_entities is not None:
|
||||||
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if_relation = await _handle_single_relationship_extraction(
|
if_relation = await _handle_single_relationship_extraction(
|
||||||
record_attributes, chunk_key
|
record_attributes, chunk_key
|
||||||
)
|
)
|
||||||
@@ -501,6 +495,58 @@ async def extract_entities(
|
|||||||
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
||||||
if_relation
|
if_relation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return maybe_nodes, maybe_edges
|
||||||
|
|
||||||
|
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
||||||
|
"""Process a single chunk
|
||||||
|
Args:
|
||||||
|
chunk_key_dp (tuple[str, TextChunkSchema]):
|
||||||
|
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
||||||
|
"""
|
||||||
|
nonlocal processed_chunks
|
||||||
|
chunk_key = chunk_key_dp[0]
|
||||||
|
chunk_dp = chunk_key_dp[1]
|
||||||
|
content = chunk_dp["content"]
|
||||||
|
|
||||||
|
# Get initial extraction
|
||||||
|
hint_prompt = entity_extract_prompt.format(
|
||||||
|
**context_base, input_text="{input_text}"
|
||||||
|
).format(**context_base, input_text=content)
|
||||||
|
|
||||||
|
final_result = await _user_llm_func_with_cache(hint_prompt)
|
||||||
|
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
||||||
|
|
||||||
|
# Process initial extraction
|
||||||
|
maybe_nodes, maybe_edges = await _process_extraction_result(final_result, chunk_key)
|
||||||
|
|
||||||
|
# Process additional gleaning results
|
||||||
|
for now_glean_index in range(entity_extract_max_gleaning):
|
||||||
|
glean_result = await _user_llm_func_with_cache(
|
||||||
|
continue_prompt, history_messages=history
|
||||||
|
)
|
||||||
|
|
||||||
|
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
||||||
|
|
||||||
|
# Process gleaning result separately
|
||||||
|
glean_nodes, glean_edges = await _process_extraction_result(glean_result, chunk_key)
|
||||||
|
|
||||||
|
# Merge results
|
||||||
|
for entity_name, entities in glean_nodes.items():
|
||||||
|
maybe_nodes[entity_name].extend(entities)
|
||||||
|
for edge_key, edges in glean_edges.items():
|
||||||
|
maybe_edges[edge_key].extend(edges)
|
||||||
|
|
||||||
|
if now_glean_index == entity_extract_max_gleaning - 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
if_loop_result: str = await _user_llm_func_with_cache(
|
||||||
|
if_loop_prompt, history_messages=history
|
||||||
|
)
|
||||||
|
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||||
|
if if_loop_result != "yes":
|
||||||
|
break
|
||||||
|
|
||||||
processed_chunks += 1
|
processed_chunks += 1
|
||||||
entities_count = len(maybe_nodes)
|
entities_count = len(maybe_nodes)
|
||||||
relations_count = len(maybe_edges)
|
relations_count = len(maybe_edges)
|
||||||
|
Reference in New Issue
Block a user