import csv
import io


def list_of_list_to_csv(data: list[list[str]]) -> str:
    """Serialize a list of rows to CSV text.

    Uses the standard ``csv`` writer with its default dialect, so fields
    containing commas, quotes or line breaks are quoted, and every row
    (including the last) is terminated with ``"\\r\\n"``.

    Args:
        data: Rows to write; each row is a sequence of field values.

    Returns:
        The CSV document as a single string (empty for no rows).
    """
    buffer = io.StringIO()
    csv.writer(buffer).writerows(data)
    return buffer.getvalue()


def csv_string_to_list(csv_string: str) -> list[list[str]]:
    """Parse CSV text into a list of rows.

    Args:
        csv_string: CSV document to parse.

    Returns:
        One list of string fields per record; blank lines yield ``[]``.
    """
    # newline="" hands line endings to the csv reader untouched, so a bare
    # "\r" separates records instead of raising csv.Error.
    source = io.StringIO(csv_string, newline="")
    return list(csv.reader(source))


# Merge helper used by hybrid retrieval.
def process_combine_contexts(hl, ll):
    """Merge two CSV context tables into one de-duplicated table.

    The first row of each input is treated as its header, and the first
    column of every remaining row is dropped (the rows are renumbered).
    Rows from ``hl`` come first, then rows from ``ll``; duplicates keep
    their first position, so the result is deterministic.

    Args:
        hl: CSV text of the first table (``None`` is treated as empty).
        ll: CSV text of the second table (``None`` is treated as empty).

    Returns:
        The header joined with ``",\\t"`` followed by one line per unique
        row in the form ``"<n>,\\t<fields joined by ','>"``, joined with
        ``"\\n"``. Returns ``""`` when neither input has a header row.
    """
    rows_hl = csv_string_to_list((hl or "").strip())
    rows_ll = csv_string_to_list((ll or "").strip())

    # Split off the header row; when both tables have one, ll's wins.
    header = None
    if rows_hl:
        header, rows_hl = rows_hl[0], rows_hl[1:]
    if rows_ll:
        header, rows_ll = rows_ll[0], rows_ll[1:]
    if header is None:
        return ""

    # Drop the leading column of each non-empty row and flatten the rest
    # back into a single string so rows can be compared for de-duplication.
    flattened = (",".join(row[1:]) for row in rows_hl + rows_ll if row)

    # NOTE(review): the original de-duplicated through a set, which made
    # the row order (and the numbers assigned below) vary between runs.
    # dict.fromkeys keeps first-seen order; empty entries are skipped.
    unique_rows = dict.fromkeys(entry for entry in flattened if entry)

    # Header first, then the unique rows with a fresh 1-based counter.
    lines = [",\t".join(header)]
    lines.extend(
        f"{index},\t{entry}" for index, entry in enumerate(unique_rows, start=1)
    )

    # One table row per output line.
    return "\n".join(lines)