Merge branch 'HKUDS:main' into main

This commit is contained in:
wiltshirek
2024-11-03 07:24:28 -05:00
committed by GitHub
2 changed files with 55 additions and 21 deletions

View File

@@ -15,6 +15,7 @@ from .utils import (
pack_user_ass_to_openai_messages, pack_user_ass_to_openai_messages,
split_string_by_multi_markers, split_string_by_multi_markers,
truncate_list_by_token_size, truncate_list_by_token_size,
process_combine_contexts,
) )
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@@ -1006,35 +1007,28 @@ def combine_contexts(high_level_context, low_level_context):
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
# Combine and deduplicate the entities # Combine and deduplicate the entities
combined_entities_set = set( combined_entities = process_combine_contexts(hl_entities, ll_entities)
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
)
combined_entities = "\n".join(combined_entities_set)
# Combine and deduplicate the relationships # Combine and deduplicate the relationships
combined_relationships_set = set( combined_relationships = process_combine_contexts(hl_relationships, ll_relationships)
filter(
None,
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
)
)
combined_relationships = "\n".join(combined_relationships_set)
# Combine and deduplicate the sources # Combine and deduplicate the sources
combined_sources_set = set( combined_sources = process_combine_contexts(hl_sources, ll_sources)
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
)
combined_sources = "\n".join(combined_sources_set)
# Format the combined context # Format the combined context
return f""" return f"""
-----Entities----- -----Entities-----
```csv ```csv
{combined_entities} {combined_entities}
```
-----Relationships----- -----Relationships-----
```csv
{combined_relationships} {combined_relationships}
```
-----Sources----- -----Sources-----
```csv
{combined_sources} {combined_sources}
``
""" """

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
import html import html
import io
import csv
import json import json
import logging import logging
import os import os
@@ -7,7 +9,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union from typing import Any, Union,List
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import numpy as np import numpy as np
@@ -174,11 +176,17 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i] return list_data[:i]
return list_data return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialize rows of fields into a single CSV-formatted string.

    The rows are written with the standard library ``csv`` writer using its
    default settings, so quoting and line termination follow that writer's
    default dialect.

    Args:
        data: Sequence of rows, each row being a sequence of field values.

    Returns:
        The CSV text for all rows, or an empty string when ``data`` is empty.
    """
    with io.StringIO() as buffer:
        csv_writer = csv.writer(buffer)
        for row in data:
            csv_writer.writerow(row)
        # Read the text out before the buffer is closed by the context manager.
        return buffer.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse CSV text into a list of rows.

    Args:
        csv_string: CSV-formatted text; may be empty.

    Returns:
        One list of field strings per parsed line, in input order. Blank
        lines are returned as empty lists by the ``csv`` reader, and an empty
        input yields an empty list.
    """
    source = io.StringIO(csv_string)
    return list(csv.reader(source))
def list_of_list_to_csv(data: list[list]):
    """Render rows of values as text, one row per line.

    Every value is converted with ``str``; fields within a row are separated
    by a comma followed by a tab, and rows are separated by a newline. No
    quoting or escaping is applied to the field values.

    Args:
        data: Rows to render; each row is an iterable of values.

    Returns:
        The joined text, or an empty string when ``data`` is empty.
    """
    rendered_rows = []
    for row in data:
        rendered_rows.append(",\t".join(map(str, row)))
    return "\n".join(rendered_rows)
def save_data_to_file(data, file_name): def save_data_to_file(data, file_name):
@@ -244,3 +252,35 @@ def xml_to_json(xml_file):
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def process_combine_contexts(hl, ll):
    """Merge two CSV-formatted context sections into one deduplicated table.

    Both inputs are parsed as CSV. The first row of each is treated as a
    header and the first column of every remaining row is discarded (it is
    replaced by a fresh running index). The remaining fields of each row are
    re-joined with "," and rows that are identical after that step are kept
    only once across both inputs.

    Rows are emitted in first-seen order (rows from ``hl`` before rows from
    ``ll``). Deduplicating through a plain ``set`` made the row order, and
    therefore the assigned indices, depend on string hashing and vary between
    interpreter runs; an insertion-ordered dict keeps the result reproducible.

    Args:
        hl: CSV text of the first section; may be empty or whitespace only.
        ll: CSV text of the second section; may be empty or whitespace only.

    Returns:
        The combined table as a string: the header fields joined by a comma
        followed by a tab, then one line per unique row made of the new index,
        a comma followed by a tab, and the row's remaining fields. Lines are
        separated by a newline. An empty string is returned when neither input
        contains any row.
    """
    rows_hl = csv_string_to_list(hl.strip())
    rows_ll = csv_string_to_list(ll.strip())

    header = None
    if rows_hl:
        header, rows_hl = rows_hl[0], rows_hl[1:]
    if rows_ll:
        # When both sections are present the header of ``ll`` takes precedence.
        header, rows_ll = rows_ll[0], rows_ll[1:]
    if header is None:
        return ""

    # Drop the leading index column. Blank CSV lines parse to empty lists and
    # are skipped, as are rows that have nothing left after the index.
    row_bodies = (",".join(row[1:]) for row in rows_hl + rows_ll if row)
    unique_bodies = dict.fromkeys(body for body in row_bodies if body)

    lines = [",\t".join(header)]
    lines.extend(
        f"{index},\t{body}" for index, body in enumerate(unique_bodies, start=1)
    )
    return "\n".join(lines)