Refactor context handling: convert context data from CSV to JSON for better LLM compatibility, replacing list_of_list_to_csv with list_of_list_to_json
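The shape change at a glance: each context section was one opaque quoted-CSV string and is now a list of dicts that json.dumps can embed directly in the prompt. A minimal sketch with a hypothetical sample row (not taken from the repo):

```python
import csv
import io

# Hypothetical section list, header row first (the layout operate.py builds).
rows = [["id", "entity", "description"], ["0", "Alice", "A researcher"]]

# Old shape: one quoted CSV string.
buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_ALL, lineterminator="\n").writerows(rows)
print(buf.getvalue())
# "id","entity","description"
# "0","Alice","A researcher"

# New shape: a list of dicts ready for json.dumps.
records = [dict(zip(rows[0], row)) for row in rows[1:]]
print(records)  # [{'id': '0', 'entity': 'Alice', 'description': 'A researcher'}]
```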
@@ -15,7 +15,6 @@ from .utils import (
     decode_tokens_by_tiktoken,
     encode_string_by_tiktoken,
     is_float_regex,
-    list_of_list_to_csv,
     normalize_extracted_info,
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
@@ -27,6 +26,7 @@ from .utils import (
     CacheData,
     get_conversation_turns,
     use_llm_func_with_cache,
+    list_of_list_to_json,
 )
 from .base import (
     BaseGraphStorage,
@@ -1311,21 +1311,26 @@ async def _build_query_context(
         [hl_text_units_context, ll_text_units_context],
     )
     # not necessary to use LLM to generate a response
-    if not entities_context.strip() and not relations_context.strip():
+    if not entities_context and not relations_context:
         return None

+    # Convert to JSON strings
+    entities_str = json.dumps(entities_context, ensure_ascii=False)
+    relations_str = json.dumps(relations_context, ensure_ascii=False)
+    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
+
     result = f"""
 -----Entities-----
 ```json
-{entities_context}
+{entities_str}
 ```
 -----Relationships-----
 ```json
-{relations_context}
+{relations_str}
 ```
 -----Sources-----
 ```json
-{text_units_context}
+{text_units_str}
 ```
 """.strip()
     return result
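One detail worth noting in the new json.dumps calls above: ensure_ascii=False keeps non-ASCII field values readable in the prompt rather than escaping them to \uXXXX sequences. A quick sketch with a hypothetical entity row:

```python
import json

# Hypothetical context row with a non-ASCII entity name.
entities_context = [{"id": 0, "entity": "居里夫人", "description": "physicist"}]

print(json.dumps(entities_context))
# [{"id": 0, "entity": "\u5c45\u91cc\u592b\u4eba", "description": "physicist"}]
print(json.dumps(entities_context, ensure_ascii=False))
# [{"id": 0, "entity": "居里夫人", "description": "physicist"}]
```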
@@ -1424,7 +1429,7 @@ async def _get_node_data(
             file_path,
         ]
     )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)

     relations_section_list = [
         [
@@ -1461,14 +1466,14 @@ async def _get_node_data(
             file_path,
         ]
     )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append(
             [i, t["content"], t.get("file_path", "unknown_source")]
         )
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)
     return entities_context, relations_context, text_units_context


@@ -1736,7 +1741,7 @@ async def _get_edge_data(
             file_path,
         ]
     )
-    relations_context = list_of_list_to_csv(relations_section_list)
+    relations_context = list_of_list_to_json(relations_section_list)

     entites_section_list = [
         ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
@@ -1761,12 +1766,12 @@ async def _get_edge_data(
             file_path,
         ]
     )
-    entities_context = list_of_list_to_csv(entites_section_list)
+    entities_context = list_of_list_to_json(entites_section_list)

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
-    text_units_context = list_of_list_to_csv(text_units_section_list)
+    text_units_context = list_of_list_to_json(text_units_section_list)
     return entities_context, relations_context, text_units_context


@@ -374,37 +374,24 @@ def truncate_list_by_token_size(
     return list_data


-def list_of_list_to_csv(data: list[list[str]]) -> str:
-    output = io.StringIO()
-    writer = csv.writer(
-        output,
-        quoting=csv.QUOTE_ALL,  # Quote all fields
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-        lineterminator="\n",  # Explicit line terminator
-    )
-    writer.writerows(data)
-    return output.getvalue()
-
-
-def csv_string_to_list(csv_string: str) -> list[list[str]]:
-    # Clean the string by removing NUL characters
-    cleaned_string = csv_string.replace("\0", "")
-
-    output = io.StringIO(cleaned_string)
-    reader = csv.reader(
-        output,
-        quoting=csv.QUOTE_ALL,  # Match the writer configuration
-        escapechar="\\",  # Use backslash as escape character
-        quotechar='"',  # Use double quotes
-    )
-
-    try:
-        return [row for row in reader]
-    except csv.Error as e:
-        raise ValueError(f"Failed to parse CSV string: {str(e)}")
-    finally:
-        output.close()
+def list_of_list_to_json(data: list[list[str]]) -> list[dict[str, str]]:
+    if not data or len(data) <= 1:
+        return []
+
+    header = data[0]
+    result = []
+
+    for row in data[1:]:
+        if len(row) >= 2:
+            item = {}
+            for i, field_name in enumerate(header):
+                if i < len(row):
+                    item[field_name] = row[i]
+                else:
+                    item[field_name] = ""
+            result.append(item)
+
+    return result


 def save_data_to_file(data, file_name):
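A usage sketch for the new helper (assuming it is importable as lightrag.utils.list_of_list_to_json; sample data is hypothetical). Rows with fewer than two fields are dropped, and rows shorter than the header are padded with empty strings:

```python
from lightrag.utils import list_of_list_to_json  # assumed import path

data = [
    ["id", "entity", "type"],  # header row
    ["0", "Alice", "person"],  # complete row
    ["1", "Bob"],              # short row: missing fields become ""
    ["2"],                     # fewer than 2 fields: dropped entirely
]
print(list_of_list_to_json(data))
# [{'id': '0', 'entity': 'Alice', 'type': 'person'},
#  {'id': '1', 'entity': 'Bob', 'type': ''}]
```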
@@ -472,50 +459,21 @@ def xml_to_json(xml_file):
     return None


-def process_combine_contexts(hl: str, ll: str):
-    list_hl = csv_string_to_list(hl.strip()) if hl.strip() else []
-    list_ll = csv_string_to_list(ll.strip()) if ll.strip() else []
-
-    if not list_hl and not list_ll:
-        return json.dumps([], ensure_ascii=False)
-
-    header = None
-    if list_hl and len(list_hl) > 0:
-        header = list_hl[0]
-        list_hl = list_hl[1:]
-    if list_ll and len(list_ll) > 0:
-        if header is None:
-            header = list_ll[0]
-        list_ll = list_ll[1:]
-
-    if header is None:
-        return json.dumps([], ensure_ascii=False)
+def process_combine_contexts(hl_context: dict, ll_context: dict):
+    seen_content = {}

     combined_data = []
-    seen = set()
-
-    def process_row(row):
-        if len(row) < 2:
-            return None
-
-        item_data = {}
-        for i, field_name in enumerate(header):
-            item_data[field_name] = row[i]
-
-        return item_data
-
-    for row in list_hl + list_ll:
-        if len(row) >= 2:
-            row_identifier = json.dumps(row[1:], ensure_ascii=False)
-
-            if row_identifier not in seen:
-                seen.add(row_identifier)
-                item = process_row(row)
-                if item:
-                    combined_data.append(item)
-
-    return json.dumps(combined_data, ensure_ascii=False)
+    for item in hl_context + ll_context:
+        content_key = {k: v for k, v in item.items() if k != 'id'}
+        content_key_str = str(content_key)
+        if content_key_str not in seen_content:
+            seen_content[content_key_str] = item
+            combined_data.append(item)
+
+    for i, item in enumerate(combined_data):
+        item['id'] = i
+
+    return combined_data


 async def get_best_cached_response(
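And the rewritten combiner: items from the high-level and low-level contexts are merged, deduplicated on every field except id (first occurrence wins, so hl_context takes precedence), then re-numbered sequentially. A sketch with hypothetical inputs (note the function mutates the input dicts when reassigning ids):

```python
from lightrag.utils import process_combine_contexts  # assumed import path

hl = [{"id": 0, "entity": "Alice"}, {"id": 1, "entity": "Bob"}]
ll = [{"id": 0, "entity": "Bob"}, {"id": 1, "entity": "Carol"}]

print(process_combine_contexts(hl, ll))
# [{'id': 0, 'entity': 'Alice'}, {'id': 1, 'entity': 'Bob'}, {'id': 2, 'entity': 'Carol'}]
```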