Update utils.py

This commit is contained in:
zrguo
2024-11-03 17:53:53 +08:00
committed by GitHub
parent 26c5f1b743
commit 3f7ae11962

View File

@@ -176,11 +176,6 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i] return list_data[:i]
return list_data return list_data
# def list_of_list_to_csv(data: list[list]):
# return "\n".join(
# [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
# )
def list_of_list_to_csv(data: List[List[str]]) -> str: def list_of_list_to_csv(data: List[List[str]]) -> str:
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
@@ -258,12 +253,11 @@ def xml_to_json(xml_file):
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
#混合检索中的合并函数
def process_combine_contexts(hl, ll): def process_combine_contexts(hl, ll):
header = None header = None
list_hl = csv_string_to_list(hl.strip()) list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip()) list_ll = csv_string_to_list(ll.strip())
# 去掉第一个元素(如果不为空)
if list_hl: if list_hl:
header=list_hl[0] header=list_hl[0]
list_hl = list_hl[1:] list_hl = list_hl[1:]
@@ -272,24 +266,21 @@ def process_combine_contexts(hl, ll):
list_ll = list_ll[1:] list_ll = list_ll[1:]
if header is None: if header is None:
return "" return ""
# 去掉每个子元素中的第一个元素(如果不为空),再转为一维数组,用于合并去重
if list_hl: if list_hl:
list_hl = [','.join(item[1:]) for item in list_hl if item] list_hl = [','.join(item[1:]) for item in list_hl if item]
if list_ll: if list_ll:
list_ll = [','.join(item[1:]) for item in list_ll if item] list_ll = [','.join(item[1:]) for item in list_ll if item]
# 合并并去重
combined_sources_set = set( combined_sources_set = set(
filter(None, list_hl + list_ll) filter(None, list_hl + list_ll)
) )
# 创建包含头部的新列表
combined_sources = [",\t".join(header)] combined_sources = [",\t".join(header)]
# 为 combined_sources_set 中的每个元素添加自增数字
for i, item in enumerate(combined_sources_set, start=1): for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}") combined_sources.append(f"{i},\t{item}")
# 将列表转换为字符串,子元素之间用换行符分隔
combined_sources = "\n".join(combined_sources) combined_sources = "\n".join(combined_sources)
return combined_sources return combined_sources