diff --git a/lightrag/operate.py b/lightrag/operate.py
index 8a6820f5..ef0d3398 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -15,6 +15,7 @@ from .utils import (
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
     truncate_list_by_token_size,
+    process_combine_contexts,
 )
 from .base import (
     BaseGraphStorage,
@@ -1003,35 +1004,28 @@ def combine_contexts(high_level_context, low_level_context):
     ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
 
     # Combine and deduplicate the entities
-    combined_entities_set = set(
-        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
-    )
-    combined_entities = "\n".join(combined_entities_set)
-
+    combined_entities = process_combine_contexts(hl_entities, ll_entities)
+
     # Combine and deduplicate the relationships
-    combined_relationships_set = set(
-        filter(
-            None,
-            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
-        )
-    )
-    combined_relationships = "\n".join(combined_relationships_set)
+    combined_relationships = process_combine_contexts(hl_relationships, ll_relationships)
 
     # Combine and deduplicate the sources
-    combined_sources_set = set(
-        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
-    )
-    combined_sources = "\n".join(combined_sources_set)
+    combined_sources = process_combine_contexts(hl_sources, ll_sources)
 
     # Format the combined context
     return f"""
 -----Entities-----
 ```csv
 {combined_entities}
+```
 -----Relationships-----
+```csv
 {combined_relationships}
+```
 -----Sources-----
+```csv
 {combined_sources}
+```
 """
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 0da4a51a..7b17cbb6 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,5 +1,7 @@
 import asyncio
 import html
+import io
+import csv
 import json
 import logging
 import os
@@ -7,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union
+from typing import Any, Union, List
 import xml.etree.ElementTree as ET
 
 import numpy as np
@@ -175,10 +177,19 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
     return list_data
 
 
-def list_of_list_to_csv(data: list[list]):
-    return "\n".join(
-        [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
-    )
+def list_of_list_to_csv(data: List[List[str]]) -> str:
+    """Serialize rows to a CSV string, with quoting handled by csv.writer."""
+    output = io.StringIO()
+    writer = csv.writer(output)
+    writer.writerows(data)
+    return output.getvalue()
+
+
+def csv_string_to_list(csv_string: str) -> List[List[str]]:
+    """Parse a CSV string back into a list of rows via csv.reader."""
+    output = io.StringIO(csv_string)
+    reader = csv.reader(output)
+    return [row for row in reader]
 
 
 def save_data_to_file(data, file_name):
@@ -244,3 +255,40 @@ def xml_to_json(xml_file):
     except Exception as e:
         print(f"An error occurred: {e}")
         return None
+
+
+def process_combine_contexts(hl: str, ll: str) -> str:
+    """Merge two CSV-formatted context blocks for hybrid retrieval."""
+    header = None
+    list_hl = csv_string_to_list(hl.strip())
+    list_ll = csv_string_to_list(ll.strip())
+    # Pop the header row from whichever side has one (both share the same header)
+    if list_hl:
+        header = list_hl[0]
+        list_hl = list_hl[1:]
+    if list_ll:
+        header = list_ll[0]
+        list_ll = list_ll[1:]
+    if header is None:
+        return ""
+    # Drop the leading id column and flatten each row to a string for deduplication
+    if list_hl:
+        list_hl = [",".join(item[1:]) for item in list_hl if item]
+    if list_ll:
+        list_ll = [",".join(item[1:]) for item in list_ll if item]
+
+    # Merge the two sides and deduplicate
+    combined_sources_set = set(
+        filter(None, list_hl + list_ll)
+    )
+
+    # Rebuild the table, starting with the header row
+    combined_sources = [",\t".join(header)]
+    # Re-number the deduplicated rows with a fresh auto-incrementing id
+    for i, item in enumerate(combined_sources_set, start=1):
+        combined_sources.append(f"{i},\t{item}")
+
+    # Join the rows into a single newline-separated string
+    combined_sources = "\n".join(combined_sources)
+
+    return combined_sources
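
Usage sketch (illustrative, not part of the patch): with the patch applied, the new merge path can be exercised directly. The sample rows below are invented for demonstration.

```python
# Hypothetical demo of process_combine_contexts, assuming the patched
# lightrag.utils is on the import path. Sample data is invented.
from lightrag.utils import list_of_list_to_csv, process_combine_contexts

# Two CSV context blocks with the same header and one overlapping row.
hl = list_of_list_to_csv([
    ["id", "entity", "description"],
    ["1", "Apple", "a fruit"],
    ["2", "Banana", "a yellow fruit"],
])
ll = list_of_list_to_csv([
    ["id", "entity", "description"],
    ["1", "Banana", "a yellow fruit"],
    ["2", "Cherry", "a red fruit"],
])

print(process_combine_contexts(hl, ll))
# The duplicate "Banana" row is collapsed and ids are reassigned; row order
# is not deterministic because deduplication goes through a set. E.g.:
#   id,	entity,	description
#   1,	Apple,a fruit
#   2,	Banana,a yellow fruit
#   3,	Cherry,a red fruit
```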
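
A quick round-trip check of the two new CSV helpers (again illustrative, not in the patch); the csv module handles quoting and embedded commas that the old plain string joins could not:

```python
# Hypothetical round-trip check for the new helpers in lightrag.utils.
from lightrag.utils import csv_string_to_list, list_of_list_to_csv

rows = [
    ["id", "entity"],
    ["1", 'needs, a comma and "quotes"'],
]
# csv.writer quotes the tricky cell; csv.reader restores it exactly.
assert csv_string_to_list(list_of_list_to_csv(rows)) == rows
```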