Merge branch 'HKUDS:main' into main

Author: jin
Date: 2025-01-16 09:59:27 +08:00 (committed by GitHub)
19 changed files with 2609 additions and 2396 deletions

@@ -4,7 +4,6 @@ import re
 from tqdm.asyncio import tqdm as tqdm_async
 from typing import Union
 from collections import Counter, defaultdict
-import warnings
 from .utils import (
     logger,
     clean_str,
@@ -34,23 +33,61 @@ import time
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str,
+    split_by_character=None,
+    split_by_character_only=False,
+    overlap_token_size=128,
+    max_token_size=1024,
+    tiktoken_model="gpt-4o",
+    **kwargs,
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        if split_by_character_only:
+            for chunk in raw_chunks:
+                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                new_chunks.append((len(_tokens), chunk))
+        else:
+            for chunk in raw_chunks:
+                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                if len(_tokens) > max_token_size:
+                    for start in range(
+                        0, len(_tokens), max_token_size - overlap_token_size
+                    ):
+                        chunk_content = decode_tokens_by_tiktoken(
+                            _tokens[start : start + max_token_size],
+                            model_name=tiktoken_model,
+                        )
+                        new_chunks.append(
+                            (min(max_token_size, len(_tokens) - start), chunk_content)
+                        )
+                else:
+                    new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start : start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
     return results
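
Note: a minimal usage sketch of the new chunking options (assuming `lightrag` is installed so `lightrag.operate` is importable and tiktoken can load the "gpt-4o" encoding; the sample text and parameter values are illustrative only):

```python
# Usage sketch for the new split_by_character options; the document text and
# parameter values below are made up for illustration.
from lightrag.operate import chunking_by_token_size

doc = "First paragraph.\n\nSecond paragraph, long enough to need re-chunking."

# Old behavior (still the default): fixed token windows with overlap.
token_chunks = chunking_by_token_size(doc, max_token_size=64, overlap_token_size=16)

# New behavior: split on a delimiter first; any piece still longer than
# max_token_size is re-chunked by tokens. With split_by_character_only=True,
# delimiter-split pieces are kept intact regardless of their token count.
char_chunks = chunking_by_token_size(
    doc,
    split_by_character="\n\n",
    split_by_character_only=False,
    max_token_size=64,
    overlap_token_size=16,
)

for chunk in char_chunks:
    print(chunk["chunk_order_index"], chunk["tokens"], repr(chunk["content"][:40]))
```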
@@ -582,15 +619,22 @@ async def kg_query(
         logger.warning("low_level_keywords and high_level_keywords is empty")
         return PROMPTS["fail_response"]
     if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
-        logger.warning("low_level_keywords is empty")
-        return PROMPTS["fail_response"]
-    else:
-        ll_keywords = ", ".join(ll_keywords)
+        logger.warning(
+            "low_level_keywords is empty, switching from %s mode to global mode",
+            query_param.mode,
+        )
+        query_param.mode = "global"
     if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
-        logger.warning("high_level_keywords is empty")
-        return PROMPTS["fail_response"]
-    else:
-        hl_keywords = ", ".join(hl_keywords)
+        logger.warning(
+            "high_level_keywords is empty, switching from %s mode to local mode",
+            query_param.mode,
+        )
+        query_param.mode = "local"
+    ll_keywords = ", ".join(ll_keywords) if ll_keywords else ""
+    hl_keywords = ", ".join(hl_keywords) if hl_keywords else ""
+    logger.info("Using %s mode for query processing", query_param.mode)
     # Build context
     keywords = [ll_keywords, hl_keywords]
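
Note: the keyword checks above now degrade the query mode instead of failing outright; only the case where both keyword lists are empty still returns `fail_response`. A standalone sketch of the resulting rule (a hypothetical helper written for illustration, not LightRAG API):

```python
# Hypothetical helper mirroring the fallback logic above: empty low-level
# keywords demote local/hybrid queries to global mode, and empty high-level
# keywords demote global/hybrid queries to local mode.
def resolve_mode(mode, ll_keywords, hl_keywords):
    if not ll_keywords and not hl_keywords:
        return None  # kg_query returns PROMPTS["fail_response"] in this case
    if not ll_keywords and mode in ("local", "hybrid"):
        mode = "global"
    if not hl_keywords and mode in ("global", "hybrid"):
        mode = "local"
    return mode

assert resolve_mode("hybrid", [], ["renewable energy policy"]) == "global"
assert resolve_mode("hybrid", ["solar panel"], []) == "local"
assert resolve_mode("local", [], []) is None
```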
@@ -656,78 +700,52 @@ async def _build_query_context(
-    # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
-    # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
-    ll_kewwords, hl_keywrds = query[0], query[1]
-    if query_param.mode in ["local", "hybrid"]:
-        if ll_kewwords == "":
-            ll_entities_context, ll_relations_context, ll_text_units_context = (
-                "",
-                "",
-                "",
-            )
-            warnings.warn(
-                "Low Level context is None. Return empty Low entity/relationship/source"
-            )
-            query_param.mode = "global"
-        else:
-            (
-                ll_entities_context,
-                ll_relations_context,
-                ll_text_units_context,
-            ) = await _get_node_data(
-                ll_kewwords,
-                knowledge_graph_inst,
-                entities_vdb,
-                text_chunks_db,
-                query_param,
-            )
-    if query_param.mode in ["global", "hybrid"]:
-        if hl_keywrds == "":
-            hl_entities_context, hl_relations_context, hl_text_units_context = (
-                "",
-                "",
-                "",
-            )
-            warnings.warn(
-                "High Level context is None. Return empty High entity/relationship/source"
-            )
-            query_param.mode = "local"
-        else:
-            (
-                hl_entities_context,
-                hl_relations_context,
-                hl_text_units_context,
-            ) = await _get_edge_data(
-                hl_keywrds,
-                knowledge_graph_inst,
-                relationships_vdb,
-                text_chunks_db,
-                query_param,
-            )
-            if (
-                hl_entities_context == ""
-                and hl_relations_context == ""
-                and hl_text_units_context == ""
-            ):
-                logger.warn("No high level context found. Switching to local mode.")
-                query_param.mode = "local"
-    if query_param.mode == "hybrid":
+    ll_keywords, hl_keywords = query[0], query[1]
+    if query_param.mode == "local":
+        entities_context, relations_context, text_units_context = await _get_node_data(
+            ll_keywords,
+            knowledge_graph_inst,
+            entities_vdb,
+            text_chunks_db,
+            query_param,
+        )
+    elif query_param.mode == "global":
+        entities_context, relations_context, text_units_context = await _get_edge_data(
+            hl_keywords,
+            knowledge_graph_inst,
+            relationships_vdb,
+            text_chunks_db,
+            query_param,
+        )
+    else:  # hybrid mode
+        (
+            ll_entities_context,
+            ll_relations_context,
+            ll_text_units_context,
+        ) = await _get_node_data(
+            ll_keywords,
+            knowledge_graph_inst,
+            entities_vdb,
+            text_chunks_db,
+            query_param,
+        )
+        (
+            hl_entities_context,
+            hl_relations_context,
+            hl_text_units_context,
+        ) = await _get_edge_data(
+            hl_keywords,
+            knowledge_graph_inst,
+            relationships_vdb,
+            text_chunks_db,
+            query_param,
+        )
         entities_context, relations_context, text_units_context = combine_contexts(
             [hl_entities_context, ll_entities_context],
             [hl_relations_context, ll_relations_context],
             [hl_text_units_context, ll_text_units_context],
         )
-    elif query_param.mode == "local":
-        entities_context, relations_context, text_units_context = (
-            ll_entities_context,
-            ll_relations_context,
-            ll_text_units_context,
-        )
-    elif query_param.mode == "global":
-        entities_context, relations_context, text_units_context = (
-            hl_entities_context,
-            hl_relations_context,
-            hl_text_units_context,
-        )
     return f"""
 -----Entities-----
 ```csv
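
Note: the refactor above replaces the old mutate-the-mode control flow with a plain three-way dispatch on `query_param.mode`, which kg_query has already normalized before this function runs. A reduced, runnable sketch of that dispatch, with dummy stand-ins for `_get_node_data`, `_get_edge_data`, and `combine_contexts`:

```python
import asyncio

# Dummy stand-ins for the real retrieval coroutines, for illustration only.
async def get_node_data(ll_keywords):
    return f"entities/relations/chunks for low-level [{ll_keywords}]"

async def get_edge_data(hl_keywords):
    return f"entities/relations/chunks for high-level [{hl_keywords}]"

def combine(hl, ll):
    # combine_contexts in the diff receives the high-level results first.
    return f"{hl} | {ll}"

async def build_context(mode, ll_keywords, hl_keywords):
    if mode == "local":
        return await get_node_data(ll_keywords)
    elif mode == "global":
        return await get_edge_data(hl_keywords)
    else:  # hybrid mode: query both sides, then merge
        ll = await get_node_data(ll_keywords)
        hl = await get_edge_data(hl_keywords)
        return combine(hl, ll)

print(asyncio.run(build_context("hybrid", "solar panel", "renewable energy policy")))
```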