Merge branch 'HKUDS:main' into main
@@ -15,6 +15,7 @@ from .utils import (
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
     truncate_list_by_token_size,
+    process_combine_contexts,
 )
 from .base import (
     BaseGraphStorage,
@@ -1006,35 +1007,28 @@ def combine_contexts(high_level_context, low_level_context):
     ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
 
     # Combine and deduplicate the entities
-    combined_entities_set = set(
-        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
-    )
-    combined_entities = "\n".join(combined_entities_set)
+    combined_entities = process_combine_contexts(hl_entities, ll_entities)
 
     # Combine and deduplicate the relationships
-    combined_relationships_set = set(
-        filter(
-            None,
-            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
-        )
-    )
-    combined_relationships = "\n".join(combined_relationships_set)
+    combined_relationships = process_combine_contexts(hl_relationships, ll_relationships)
 
     # Combine and deduplicate the sources
-    combined_sources_set = set(
-        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
-    )
-    combined_sources = "\n".join(combined_sources_set)
+    combined_sources = process_combine_contexts(hl_sources, ll_sources)
 
     # Format the combined context
     return f"""
 -----Entities-----
 ```csv
 {combined_entities}
+```
 -----Relationships-----
+```csv
 {combined_relationships}
+```
 -----Sources-----
+```csv
 {combined_sources}
+```
 """
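For contrast, a minimal sketch of what the removed deduplication did: each CSV block was split on newlines and deduplicated line by line with set(), so the per-block row index took part in the comparison and the merged row order was arbitrary. The input strings below are made-up examples, not repository data:

```python
# Hypothetical CSV blocks of the shape combine_contexts receives; the row
# index in column 0 differs between the high-level and low-level results.
hl_entities = 'id,\tentity\n0,\t"APPLE"\n1,\t"BANANA"'
ll_entities = 'id,\tentity\n0,\t"BANANA"\n1,\t"CHERRY"'

# Removed behaviour: whole-line dedup. The two "BANANA" rows survive as two
# separate lines because their index columns differ, and set order is arbitrary.
combined_entities_set = set(
    filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
)
combined_entities = "\n".join(combined_entities_set)
print(combined_entities)
```

The replacement routes all three blocks through process_combine_contexts (added to utils.py below), which parses the CSV, drops the index column before deduplicating, and renumbers the merged rows under a single header.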
@@ -1,5 +1,7 @@
 import asyncio
 import html
+import io
+import csv
 import json
 import logging
 import os
@@ -7,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union
+from typing import Any, Union,List
 import xml.etree.ElementTree as ET
 
 import numpy as np
@@ -174,11 +176,17 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
             return list_data[:i]
     return list_data
 
 
-def list_of_list_to_csv(data: list[list]):
-    return "\n".join(
-        [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
-    )
+def list_of_list_to_csv(data: List[List[str]]) -> str:
+    output = io.StringIO()
+    writer = csv.writer(output)
+    writer.writerows(data)
+    return output.getvalue()
+
+def csv_string_to_list(csv_string: str) -> List[List[str]]:
+    output = io.StringIO(csv_string)
+    reader = csv.reader(output)
+    return [row for row in reader]
 
 
 def save_data_to_file(data, file_name):
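The two new helpers replace the tab-and-comma string joining of the old list_of_list_to_csv with Python's csv module, so field values that themselves contain commas, quotes, or newlines round-trip correctly. A minimal, self-contained usage sketch (the example rows are invented):

```python
import csv
import io
from typing import List


def list_of_list_to_csv(data: List[List[str]]) -> str:
    # Serialize rows with csv.writer so embedded commas/quotes are quoted.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerows(data)
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> List[List[str]]:
    # Parse the CSV text back into a list of rows.
    output = io.StringIO(csv_string)
    reader = csv.reader(output)
    return [row for row in reader]


rows = [["id", "entity", "description"],
        ["0", "APPLE", "A fruit, and a company"]]
text = list_of_list_to_csv(rows)   # the embedded comma is quoted, not split
assert csv_string_to_list(text) == rows
```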
@@ -244,3 +252,35 @@ def xml_to_json(xml_file):
     except Exception as e:
         print(f"An error occurred: {e}")
         return None
+
+def process_combine_contexts(hl, ll):
+    header = None
+    list_hl = csv_string_to_list(hl.strip())
+    list_ll = csv_string_to_list(ll.strip())
+
+    if list_hl:
+        header = list_hl[0]
+        list_hl = list_hl[1:]
+    if list_ll:
+        header = list_ll[0]
+        list_ll = list_ll[1:]
+    if header is None:
+        return ""
+
+    if list_hl:
+        list_hl = [','.join(item[1:]) for item in list_hl if item]
+    if list_ll:
+        list_ll = [','.join(item[1:]) for item in list_ll if item]
+
+    combined_sources_set = set(
+        filter(None, list_hl + list_ll)
+    )
+
+    combined_sources = [",\t".join(header)]
+
+    for i, item in enumerate(combined_sources_set, start=1):
+        combined_sources.append(f"{i},\t{item}")
+
+    combined_sources = "\n".join(combined_sources)
+
+    return combined_sources
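A short usage sketch for the new helper, with invented CSV blocks of the shape combine_contexts passes in (a header row plus a leading index column). The index is stripped before deduplication, identical rows from the two blocks collapse, and the result is renumbered under a single header:

```python
# Assumes process_combine_contexts (and csv_string_to_list) from this commit
# are in scope; the CSV blocks below are made-up examples.
hl = "id,entity,description\n0,APPLE,A fruit company\n1,BANANA,A fruit"
ll = "id,entity,description\n0,BANANA,A fruit\n1,CHERRY,Also a fruit"

print(process_combine_contexts(hl, ll))
# One possible result (set iteration order varies between runs):
#   id,<TAB>entity,<TAB>description
#   1,<TAB>BANANA,A fruit
#   2,<TAB>APPLE,A fruit company
#   3,<TAB>CHERRY,Also a fruit
```

Note that the comparison uses the full joined field string, so rows that differ only in their description are kept as separate entries.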