diff --git a/lightrag/operate.py b/lightrag/operate.py
index 518bd68a..2edeb548 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -15,6 +15,7 @@ from .utils import (
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
     truncate_list_by_token_size,
+    process_combine_contexts,
 )
 from .base import (
     BaseGraphStorage,
@@ -1006,35 +1007,28 @@ def combine_contexts(high_level_context, low_level_context):
     ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
 
     # Combine and deduplicate the entities
-    combined_entities_set = set(
-        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
-    )
-    combined_entities = "\n".join(combined_entities_set)
-
+    combined_entities = process_combine_contexts(hl_entities, ll_entities)
+
     # Combine and deduplicate the relationships
-    combined_relationships_set = set(
-        filter(
-            None,
-            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
-        )
-    )
-    combined_relationships = "\n".join(combined_relationships_set)
+    combined_relationships = process_combine_contexts(hl_relationships, ll_relationships)
 
     # Combine and deduplicate the sources
-    combined_sources_set = set(
-        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
-    )
-    combined_sources = "\n".join(combined_sources_set)
+    combined_sources = process_combine_contexts(hl_sources, ll_sources)
 
     # Format the combined context
     return f"""
 -----Entities-----
 ```csv
 {combined_entities}
+```
 -----Relationships-----
+```csv
 {combined_relationships}
+```
 -----Sources-----
+```csv
 {combined_sources}
+```
 """
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 0da4a51a..254f5dad 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,5 +1,7 @@
 import asyncio
 import html
+import io
+import csv
 import json
 import logging
 import os
@@ -7,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union
+from typing import Any, List, Union
 import xml.etree.ElementTree as ET
 
 import numpy as np
@@ -174,11 +176,17 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
             return list_data[:i]
     return list_data
 
+
+def list_of_list_to_csv(data: List[List[str]]) -> str:
+    # Serialize rows with the stdlib csv writer so commas and quotes
+    # inside fields are escaped correctly.
+    output = io.StringIO()
+    writer = csv.writer(output)
+    writer.writerows(data)
+    return output.getvalue()
+
+
+def csv_string_to_list(csv_string: str) -> List[List[str]]:
+    # Inverse of list_of_list_to_csv: parse a CSV string back into rows.
+    output = io.StringIO(csv_string)
+    reader = csv.reader(output)
+    return [row for row in reader]
 
-def list_of_list_to_csv(data: list[list]):
-    return "\n".join(
-        [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
-    )
 
 def save_data_to_file(data, file_name):
@@ -244,3 +252,35 @@ def xml_to_json(xml_file):
     except Exception as e:
         print(f"An error occurred: {e}")
         return None
+
+
+def process_combine_contexts(hl, ll):
+    """Merge two CSV context blocks, deduplicate rows, and renumber them."""
+    header = None
+    list_hl = csv_string_to_list(hl.strip())
+    list_ll = csv_string_to_list(ll.strip())
+
+    if list_hl:
+        header = list_hl[0]
+        list_hl = list_hl[1:]
+    if list_ll:
+        header = list_ll[0]
+        list_ll = list_ll[1:]
+    if header is None:
+        return ""
+
+    # Drop the leading id column so identical rows from the two contexts
+    # compare equal regardless of their original numbering.
+    if list_hl:
+        list_hl = [",".join(item[1:]) for item in list_hl if item]
+    if list_ll:
+        list_ll = [",".join(item[1:]) for item in list_ll if item]
+
+    combined_sources_set = set(filter(None, list_hl + list_ll))
+
+    combined_sources = [",\t".join(header)]
+    for i, item in enumerate(combined_sources_set, start=1):
+        combined_sources.append(f"{i},\t{item}")
+
+    combined_sources = "\n".join(combined_sources)
+
+    return combined_sources