Refactor context handling to convert data from CSV to JSON format for improved compatibility with LLMs, replacing the list_of_list_to_csv function with list_of_list_to_json

mengchao
2025-04-20 19:24:05 +08:00
parent 6d486f5813
commit f2f3a2721d
2 changed files with 42 additions and 79 deletions


@@ -15,7 +15,6 @@ from .utils import (
     decode_tokens_by_tiktoken,
     encode_string_by_tiktoken,
     is_float_regex,
-    list_of_list_to_csv,
     normalize_extracted_info,
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
@@ -27,6 +26,7 @@ from .utils import (
     CacheData,
     get_conversation_turns,
     use_llm_func_with_cache,
+    list_of_list_to_json,
 )
 from .base import (
     BaseGraphStorage,
@@ -1311,21 +1311,26 @@ async def _build_query_context(
         [hl_text_units_context, ll_text_units_context],
     )
     # not necessary to use LLM to generate a response
-    if not entities_context.strip() and not relations_context.strip():
+    if not entities_context and not relations_context:
         return None

+    # Convert to JSON strings
+    entities_str = json.dumps(entities_context, ensure_ascii=False)
+    relations_str = json.dumps(relations_context, ensure_ascii=False)
+    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
+
     result = f"""
 -----Entities-----
 ```json
-{entities_context}
+{entities_str}
 ```
 -----Relationships-----
 ```json
-{relations_context}
+{relations_str}
 ```
 -----Sources-----
 ```json
-{text_units_context}
+{text_units_str}
 ```
 """.strip()

     return result
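For illustration, a minimal sketch of what the refactored block now serializes; the entity row below is made up, while `json.dumps(..., ensure_ascii=False)` is the same call the hunk adds:

```python
import json

# Hypothetical context row: after the refactor, _get_node_data returns a
# list of dicts instead of a CSV string.
entities_context = [
    {"id": 0, "entity": "Alice", "type": "person", "description": "Example", "rank": 3}
]

entities_str = json.dumps(entities_context, ensure_ascii=False)
# entities_str is valid JSON the LLM can parse:
# '[{"id": 0, "entity": "Alice", "type": "person", "description": "Example", "rank": 3}]'
```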
@@ -1424,7 +1429,7 @@ async def _get_node_data(
                 file_path,
             ]
         )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)

     relations_section_list = [
         [
@@ -1461,14 +1466,14 @@ async def _get_node_data(
                 file_path,
             ]
         )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append(
             [i, t["content"], t.get("file_path", "unknown_source")]
         )
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)

     return entities_context, relations_context, text_units_context
@@ -1736,7 +1741,7 @@ async def _get_edge_data(
                 file_path,
             ]
         )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)

     entites_section_list = [
         ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
@@ -1761,12 +1766,12 @@ async def _get_edge_data(
                 file_path,
             ]
         )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)

     return entities_context, relations_context, text_units_context
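Note that the section lists keep their header-plus-rows shape; only the serializer changed. A small sketch of the text-unit path above, with made-up chunks (`use_text_units` here is hypothetical data, not real retrieval output):

```python
# Two made-up text chunks; the second has no file_path and falls back
# to "unknown", matching the hunk above.
use_text_units = [
    {"content": "Alice met Bob.", "file_path": "docs/a.txt"},
    {"content": "Bob left."},
]

text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
    text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])

# [["id", "content", "file_path"],
#  [0, "Alice met Bob.", "docs/a.txt"],
#  [1, "Bob left.", "unknown"]]
```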


@@ -374,37 +374,24 @@ def truncate_list_by_token_size(
     return list_data


-def list_of_list_to_csv(data: list[list[str]]) -> str:
-    output = io.StringIO()
-    writer = csv.writer(
-        output,
-        quoting=csv.QUOTE_ALL,  # Quote all fields
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-        lineterminator="\n",  # Explicit line terminator
-    )
-    writer.writerows(data)
-    return output.getvalue()
-
-
-def csv_string_to_list(csv_string: str) -> list[list[str]]:
-    # Clean the string by removing NUL characters
-    cleaned_string = csv_string.replace("\0", "")
-
-    output = io.StringIO(cleaned_string)
-    reader = csv.reader(
-        output,
-        quoting=csv.QUOTE_ALL,  # Match the writer configuration
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-    )
-
-    try:
-        return [row for row in reader]
-    except csv.Error as e:
-        raise ValueError(f"Failed to parse CSV string: {str(e)}")
-    finally:
-        output.close()
+def list_of_list_to_json(data: list[list[str]]) -> list[dict[str, str]]:
+    if not data or len(data) <= 1:
+        return []
+
+    header = data[0]
+    result = []
+
+    for row in data[1:]:
+        if len(row) >= 2:
+            item = {}
+            for i, field_name in enumerate(header):
+                if i < len(row):
+                    item[field_name] = row[i]
+                else:
+                    item[field_name] = ""
+            result.append(item)
+
+    return result


 def save_data_to_file(data, file_name):
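A quick usage sketch of the new helper with made-up rows (and `list_of_list_to_json` from the hunk above in scope); short rows are padded with empty strings, and rows with fewer than two fields are skipped by the `len(row) >= 2` guard:

```python
rows = [
    ["id", "content", "file_path"],       # header row defines the keys
    [0, "Alice met Bob.", "docs/a.txt"],
    [1, "Bob left."],                     # missing trailing field -> ""
]

print(list_of_list_to_json(rows))
# [{'id': 0, 'content': 'Alice met Bob.', 'file_path': 'docs/a.txt'},
#  {'id': 1, 'content': 'Bob left.', 'file_path': ''}]
```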
def save_data_to_file(data, file_name): def save_data_to_file(data, file_name):
@@ -472,50 +459,21 @@ def xml_to_json(xml_file):
     return None


-def process_combine_contexts(hl: str, ll: str):
-    list_hl = csv_string_to_list(hl.strip()) if hl.strip() else []
-    list_ll = csv_string_to_list(ll.strip()) if ll.strip() else []
-
-    if not list_hl and not list_ll:
-        return json.dumps([], ensure_ascii=False)
-
-    header = None
-    if list_hl and len(list_hl) > 0:
-        header = list_hl[0]
-        list_hl = list_hl[1:]
-    if list_ll and len(list_ll) > 0:
-        if header is None:
-            header = list_ll[0]
-        list_ll = list_ll[1:]
-
-    if header is None:
-        return json.dumps([], ensure_ascii=False)
-
-    combined_data = []
-    seen = set()
-
-    def process_row(row):
-        if len(row) < 2:
-            return None
-
-        item_data = {}
-        for i, field_name in enumerate(header):
-            item_data[field_name] = row[i]
-        return item_data
-
-    for row in list_hl + list_ll:
-        if len(row) >= 2:
-            row_identifier = json.dumps(row[1:], ensure_ascii=False)
-            if row_identifier not in seen:
-                seen.add(row_identifier)
-                item = process_row(row)
-                if item:
-                    combined_data.append(item)
-
-    return json.dumps(combined_data, ensure_ascii=False)
+def process_combine_contexts(hl_context: list[dict], ll_context: list[dict]):
+    seen_content = {}
+    combined_data = []
+
+    for item in hl_context + ll_context:
+        content_key = {k: v for k, v in item.items() if k != "id"}
+        content_key_str = str(content_key)
+        if content_key_str not in seen_content:
+            seen_content[content_key_str] = item
+            combined_data.append(item)
+
+    for i, item in enumerate(combined_data):
+        item["id"] = i
+
+    return combined_data


 async def get_best_cached_response(
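And a quick check of the new dedup behavior (with `process_combine_contexts` from the hunk above in scope): items matching on every field except `id` are collapsed, then ids are reassigned sequentially. The input rows here are made up:

```python
hl = [{"id": 0, "entity": "Alice", "rank": 3}]
ll = [
    {"id": 5, "entity": "Alice", "rank": 3},  # same content as hl[0] -> dropped
    {"id": 1, "entity": "Bob", "rank": 1},
]

print(process_combine_contexts(hl, ll))
# [{'id': 0, 'entity': 'Alice', 'rank': 3}, {'id': 1, 'entity': 'Bob', 'rank': 1}]
```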