Merge branch 'context_format_csv_to_json'

yangdx
2025-04-22 12:25:50 +08:00
2 changed files with 48 additions and 75 deletions

lightrag/operate.py

@@ -14,7 +14,6 @@ from .utils import (
     compute_mdhash_id,
     Tokenizer,
     is_float_regex,
-    list_of_list_to_csv,
     normalize_extracted_info,
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
@@ -26,6 +25,7 @@ from .utils import (
     CacheData,
     get_conversation_turns,
     use_llm_func_with_cache,
+    list_of_list_to_json,
 )
 from .base import (
     BaseGraphStorage,
@@ -1333,21 +1333,26 @@ async def _build_query_context(
         [hl_text_units_context, ll_text_units_context],
     )
 
     # not necessary to use LLM to generate a response
-    if not entities_context.strip() and not relations_context.strip():
+    if not entities_context and not relations_context:
         return None
 
+    # Convert to JSON strings
+    entities_str = json.dumps(entities_context, ensure_ascii=False)
+    relations_str = json.dumps(relations_context, ensure_ascii=False)
+    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
+
     result = f"""
 -----Entities-----
-```csv
-{entities_context}
+```json
+{entities_str}
 ```
 -----Relationships-----
-```csv
-{relations_context}
+```json
+{relations_str}
 ```
 -----Sources-----
-```csv
-{text_units_context}
+```json
+{text_units_str}
 ```
 """.strip()
     return result
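
Since the contexts are now Python lists of dicts rather than CSV strings, plain truthiness replaces the old `.strip()` emptiness check. For reference, a minimal sketch of one section of the new prompt payload after this change (the entity row is invented for illustration):

    import json

    # Hypothetical context, shaped like list_of_list_to_json output
    entities_context = [
        {"id": "0", "entity": "Alice", "type": "person", "description": "An example entity"},
    ]

    entities_str = json.dumps(entities_context, ensure_ascii=False)
    print(f"-----Entities-----\n```json\n{entities_str}\n```")
    # -----Entities-----
    # ```json
    # [{"id": "0", "entity": "Alice", "type": "person", "description": "An example entity"}]
    # ```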
@@ -1453,7 +1458,7 @@ async def _get_node_data(
             file_path,
         ]
     )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)
 
     relations_section_list = [
         [
@@ -1490,14 +1495,14 @@ async def _get_node_data(
             file_path,
         ]
     )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)
 
     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append(
             [i, t["content"], t.get("file_path", "unknown_source")]
         )
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)
 
     return entities_context, relations_context, text_units_context
@@ -1775,7 +1780,7 @@ async def _get_edge_data(
             file_path,
         ]
     )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)
 
     entites_section_list = [
         ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
@@ -1800,12 +1805,12 @@ async def _get_edge_data(
             file_path,
         ]
     )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)
 
     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)
 
     return entities_context, relations_context, text_units_context
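
Both `_get_node_data` and `_get_edge_data` now build each section as a header row plus data rows and convert it to a list of dicts. A small sketch of that conversion (sample chunks invented; the import path assumes this repo's package layout):

    from lightrag.utils import list_of_list_to_json

    text_units_section_list = [
        ["id", "content", "file_path"],  # header row
        [0, "First chunk of text", "docs/a.txt"],
        [1, "Second chunk of text", "unknown"],
    ]

    print(list_of_list_to_json(text_units_section_list))
    # [{'id': '0', 'content': 'First chunk of text', 'file_path': 'docs/a.txt'},
    #  {'id': '1', 'content': 'Second chunk of text', 'file_path': 'unknown'}]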

lightrag/utils.py

@@ -2,7 +2,6 @@ from __future__ import annotations
 import asyncio
 import html
 import io
-import csv
 import json
 import logging
@@ -442,37 +441,24 @@ def truncate_list_by_token_size(
     return list_data
 
 
-def list_of_list_to_csv(data: list[list[str]]) -> str:
-    output = io.StringIO()
-    writer = csv.writer(
-        output,
-        quoting=csv.QUOTE_ALL,  # Quote all fields
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-        lineterminator="\n",  # Explicit line terminator
-    )
-    writer.writerows(data)
-    return output.getvalue()
-
-
-def csv_string_to_list(csv_string: str) -> list[list[str]]:
-    # Clean the string by removing NUL characters
-    cleaned_string = csv_string.replace("\0", "")
-    output = io.StringIO(cleaned_string)
-    reader = csv.reader(
-        output,
-        quoting=csv.QUOTE_ALL,  # Match the writer configuration
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-    )
-    try:
-        return [row for row in reader]
-    except csv.Error as e:
-        raise ValueError(f"Failed to parse CSV string: {str(e)}")
-    finally:
-        output.close()
+def list_of_list_to_json(data: list[list[str]]) -> list[dict[str, str]]:
+    # First row is the header; header-only (or empty) input yields an empty table
+    if not data or len(data) <= 1:
+        return []
+
+    header = data[0]
+    result = []
+
+    for row in data[1:]:
+        # Skip rows that carry fewer than two fields
+        if len(row) >= 2:
+            item = {}
+            for i, field_name in enumerate(header):
+                # Pad missing trailing fields with empty strings
+                if i < len(row):
+                    item[field_name] = str(row[i])
+                else:
+                    item[field_name] = ""
+            result.append(item)
+
+    return result
 
 
 def save_data_to_file(data, file_name):
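
As a usage note (rows invented for illustration): the new converter returns an empty list for header-only input, pads rows shorter than the header with empty strings, and silently drops rows with fewer than two fields:

    from lightrag.utils import list_of_list_to_json

    assert list_of_list_to_json([["id", "entity"]]) == []  # header only

    rows = [
        ["id", "entity", "type"],
        ["0", "Alice"],  # shorter than header -> padded with ""
        ["1"],           # fewer than 2 fields -> dropped
    ]
    print(list_of_list_to_json(rows))
    # [{'id': '0', 'entity': 'Alice', 'type': ''}]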
@@ -540,41 +526,23 @@ def xml_to_json(xml_file):
         return None
 
 
-def process_combine_contexts(hl: str, ll: str):
-    header = None
-    list_hl = csv_string_to_list(hl.strip())
-    list_ll = csv_string_to_list(ll.strip())
-
-    if list_hl:
-        header = list_hl[0]
-        list_hl = list_hl[1:]
-    if list_ll:
-        header = list_ll[0]
-        list_ll = list_ll[1:]
-    if header is None:
-        return ""
-
-    if list_hl:
-        list_hl = [",".join(item[1:]) for item in list_hl if item]
-    if list_ll:
-        list_ll = [",".join(item[1:]) for item in list_ll if item]
-
-    combined_sources = []
-    seen = set()
-
-    for item in list_hl + list_ll:
-        if item and item not in seen:
-            combined_sources.append(item)
-            seen.add(item)
-
-    combined_sources_result = [",\t".join(header)]
-
-    for i, item in enumerate(combined_sources, start=1):
-        combined_sources_result.append(f"{i},\t{item}")
-
-    combined_sources_result = "\n".join(combined_sources_result)
-
-    return combined_sources_result
+def process_combine_contexts(
+    hl_context: list[dict[str, str]], ll_context: list[dict[str, str]]
+):
+    seen_content = {}
+    combined_data = []
+
+    # Deduplicate on all fields except "id"
+    for item in hl_context + ll_context:
+        content_dict = {k: v for k, v in item.items() if k != "id"}
+        content_key = tuple(sorted(content_dict.items()))
+        if content_key not in seen_content:
+            seen_content[content_key] = item
+            combined_data.append(item)
+
+    # Reassign sequential ids after deduplication
+    for i, item in enumerate(combined_data):
+        item["id"] = str(i)
+
+    return combined_data
 
 
 async def get_best_cached_response(
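
A quick sketch of the new merge behavior (entity rows invented): items that match on every key except "id" collapse to a single entry, and ids are renumbered sequentially across the combined high-level and low-level lists. Note the function mutates the input dicts when it reassigns ids:

    from lightrag.utils import process_combine_contexts

    hl = [{"id": "0", "entity": "Alice"}, {"id": "1", "entity": "Bob"}]
    ll = [{"id": "0", "entity": "Alice"}, {"id": "1", "entity": "Carol"}]

    print(process_combine_contexts(hl, ll))
    # [{'id': '0', 'entity': 'Alice'},
    #  {'id': '1', 'entity': 'Bob'},
    #  {'id': '2', 'entity': 'Carol'}]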