Friendly implementation of entity extraction and relationship weight extraction for low-capability LLMs

This commit is contained in:
tackhwa
2025-04-21 16:52:13 +08:00
parent bfce14d41f
commit f3c57b606e
2 changed files with 31 additions and 1 deletions

View File

@@ -18,6 +18,7 @@ from .utils import (
normalize_extracted_info, normalize_extracted_info,
pack_user_ass_to_openai_messages, pack_user_ass_to_openai_messages,
split_string_by_multi_markers, split_string_by_multi_markers,
extract_fixed_parenthesized_content,
truncate_list_by_token_size, truncate_list_by_token_size,
process_combine_contexts, process_combine_contexts,
compute_args_hash, compute_args_hash,
@@ -215,7 +216,7 @@ async def _handle_single_relationship_extraction(
edge_source_id = chunk_key edge_source_id = chunk_key
weight = ( weight = (
float(record_attributes[-1].strip('"').strip("'")) float(record_attributes[-1].strip('"').strip("'"))
if is_float_regex(record_attributes[-1]) if is_float_regex(record_attributes[-1].strip('"').strip("'"))
else 1.0 else 1.0
) )
return dict( return dict(
@@ -549,6 +550,8 @@ async def extract_entities(
[context_base["record_delimiter"], context_base["completion_delimiter"]], [context_base["record_delimiter"], context_base["completion_delimiter"]],
) )
records = extract_fixed_parenthesized_content(records)
for record in records: for record in records:
record = re.search(r"\((.*)\)", record) record = re.search(r"\((.*)\)", record)
if record is None: if record is None:

View File

@@ -408,6 +408,33 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
return [r.strip() for r in results if r.strip()] return [r.strip() for r in results if r.strip()]
def extract_fixed_parenthesized_content(records: list[str]) -> list[str]:
    """Pull parenthesized fragments out of *records*, repairing lone parens.

    Each record is scanned with three patterns, in this order, and every
    captured fragment is re-emitted wrapped in exactly one "(" and one ")":

    1. a "(" ... ")" pair (lazy match, so the fragment ends at the first ")"
       after the "(" and cannot span a line break);
    2. a "(" followed by paren-free text running to the end of the record
       (opening parenthesis that was never closed);
    3. paren-free text at the very start of the record followed by ")"
       (closing parenthesis that was never opened).

    Records that contain no parenthesis contribute nothing to the result.

    NOTE: nesting is not tracked. For a record such as "(x (y) z)" the lazy
    pair match yields "(x (y)" and the remainder is discarded.

    Args:
        records: Raw record strings to scan.

    Returns:
        All wrapped fragments, grouped per record in input order; within one
        record, pair matches come first, then the unclosed fragment, then the
        unopened fragment.
    """
    # The tuple order defines the output order within a single record.
    fragment_patterns = (
        r'\((.*?)\)',     # properly matched pair
        r'\(([^()]*?)$',  # opening without closing
        r'^([^()]*?)\)',  # closing without opening
    )

    wrapped: list[str] = []
    for text in records:
        for pattern in fragment_patterns:
            wrapped.extend(f"({inner})" for inner in re.findall(pattern, text))
    return wrapped
# Refer the utils functions of the official GraphRAG implementation: # Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag # https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str: def clean_str(input: Any) -> str: