Update utils.py

This commit is contained in:
gogoswift
2024-10-31 14:31:26 +08:00
committed by GitHub
parent 00d570e509
commit 7d884b9783

View File

@@ -1,5 +1,7 @@
import asyncio
import html
import io
import csv
import json
import logging
import os
@@ -7,7 +9,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union
from typing import Any, Union,List
import xml.etree.ElementTree as ET
import numpy as np
@@ -175,10 +177,21 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data
def list_of_list_to_csv(data: list[list]):
return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
)
# def list_of_list_to_csv(data: list[list]):
# return "\n".join(
# [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
# )
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialize a list of rows into CSV text.

    Rows are written with ``csv.writer`` using its default dialect, so a
    field that contains the delimiter, a quote character or a line break is
    quoted instead of silently breaking the row structure.

    Args:
        data: Rows to write; each inner list becomes one CSV record.

    Returns:
        The CSV text for all rows, or an empty string when ``data`` is empty.
    """
    with io.StringIO() as buffer:
        csv.writer(buffer).writerows(data)
        # Read the text out before the context manager closes the buffer.
        return buffer.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse CSV text into a list of rows.

    The text is read with ``csv.reader`` using its default dialect, so quoted
    fields that contain the delimiter are kept as a single field.

    Args:
        csv_string: CSV-formatted text to parse.

    Returns:
        One list of field strings per record; an empty list when
        ``csv_string`` is empty.
    """
    return list(csv.reader(io.StringIO(csv_string)))
def save_data_to_file(data, file_name):
@@ -248,8 +261,8 @@ def xml_to_json(xml_file):
# Merge function used in hybrid retrieval
def process_combine_contexts(hl, ll):
header = None
list_hl = hl.strip().split("\n")
list_ll = ll.strip().split("\n")
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
# Remove the first element (if the list is not empty)
if list_hl:
header=list_hl[0]
@@ -259,12 +272,11 @@ def process_combine_contexts(hl, ll):
list_ll = list_ll[1:]
if header is None:
return ""
# Remove the first comma-separated field of each item (if the item is not empty)
# Remove the first element of each row (if the row is not empty), then flatten to a one-dimensional list for merging and deduplication
if list_hl:
list_hl = [','.join(item.split(',')[1:]) for item in list_hl if item]
list_hl = [','.join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [','.join(item.split(',')[1:]) for item in list_ll if item]
list_ll = [','.join(item[1:]) for item in list_ll if item]
# Merge and deduplicate
combined_sources_set = set(
@@ -272,12 +284,12 @@ def process_combine_contexts(hl, ll):
)
# Create a new list that starts with the header
combined_sources = [header]
combined_sources = [",\t".join(header)]
# Prefix each element of combined_sources_set with an auto-incrementing number
for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}")
# Convert the list to a string, with elements separated by newlines
# Convert the list to a string, with elements separated by newlines
combined_sources = "\n".join(combined_sources)
return combined_sources