Update utils.py

This commit is contained in:
gogoswift
2024-10-31 14:31:26 +08:00
committed by GitHub
parent 00d570e509
commit 7d884b9783

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
import html import html
import io
import csv
import json import json
import logging import logging
import os import os
@@ -7,7 +9,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union from typing import Any, Union,List
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import numpy as np import numpy as np
@@ -175,10 +177,21 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data return list_data
def list_of_list_to_csv(data: list[list]): # def list_of_list_to_csv(data: list[list]):
return "\n".join( # return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data] # [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
) # )
def list_of_list_to_csv(data: List[List[str]]) -> str:
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(data)
return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
output = io.StringIO(csv_string)
reader = csv.reader(output)
return [row for row in reader]
def save_data_to_file(data, file_name): def save_data_to_file(data, file_name):
@@ -248,8 +261,8 @@ def xml_to_json(xml_file):
#混合检索中的合并函数 #混合检索中的合并函数
def process_combine_contexts(hl, ll): def process_combine_contexts(hl, ll):
header = None header = None
list_hl = hl.strip().split("\n") list_hl = csv_string_to_list(hl.strip())
list_ll = ll.strip().split("\n") list_ll = csv_string_to_list(ll.strip())
# 去掉第一个元素(如果不为空) # 去掉第一个元素(如果不为空)
if list_hl: if list_hl:
header=list_hl[0] header=list_hl[0]
@@ -259,12 +272,11 @@ def process_combine_contexts(hl, ll):
list_ll = list_ll[1:] list_ll = list_ll[1:]
if header is None: if header is None:
return "" return ""
# 去掉每个子元素中的第一个元素(如果不为空),再转为一维数组,用于合并去重
# 去掉每个子元素中逗号分隔后的第一个元素(如果不为空)
if list_hl: if list_hl:
list_hl = [','.join(item.split(',')[1:]) for item in list_hl if item] list_hl = [','.join(item[1:]) for item in list_hl if item]
if list_ll: if list_ll:
list_ll = [','.join(item.split(',')[1:]) for item in list_ll if item] list_ll = [','.join(item[1:]) for item in list_ll if item]
# 合并并去重 # 合并并去重
combined_sources_set = set( combined_sources_set = set(
@@ -272,12 +284,12 @@ def process_combine_contexts(hl, ll):
) )
# 创建包含头部的新列表 # 创建包含头部的新列表
combined_sources = [header] combined_sources = [",\t".join(header)]
# 为 combined_sources_set 中的每个元素添加自增数字 # 为 combined_sources_set 中的每个元素添加自增数字
for i, item in enumerate(combined_sources_set, start=1): for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}") combined_sources.append(f"{i},\t{item}")
# 将列表转换为字符串,元素之间用换行符分隔 # 将列表转换为字符串,元素之间用换行符分隔
combined_sources = "\n".join(combined_sources) combined_sources = "\n".join(combined_sources)
return combined_sources return combined_sources