Merge branch 'HKUDS:main' into main
This commit is contained in:
@@ -4,7 +4,6 @@ import re
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from typing import Union
|
||||
from collections import Counter, defaultdict
|
||||
import warnings
|
||||
from .utils import (
|
||||
logger,
|
||||
clean_str,
|
||||
@@ -34,23 +33,61 @@ import time
|
||||
|
||||
|
||||
def chunking_by_token_size(
    content: str,
    split_by_character=None,
    split_by_character_only=False,
    overlap_token_size=128,
    max_token_size=1024,
    tiktoken_model="gpt-4o",
    **kwargs,
):
    """Split *content* into chunks bounded by a token budget.

    Two strategies are supported:

    * ``split_by_character`` is falsy: the whole text is tokenised with
      ``encode_string_by_tiktoken`` and cut into overlapping windows of at
      most ``max_token_size`` tokens, advancing
      ``max_token_size - overlap_token_size`` tokens per window.
    * ``split_by_character`` is set: the text is first split with
      ``str.split(split_by_character)``. With ``split_by_character_only``
      every piece is kept as-is, whatever its token count. Otherwise pieces
      longer than ``max_token_size`` tokens are further cut into overlapping
      windows as above.

    Args:
        content: Text to chunk.
        split_by_character: Separator passed to ``str.split``; ``None`` or an
            empty string selects pure token-window chunking.
        split_by_character_only: When true (and ``split_by_character`` is
            set), never sub-split a piece by token count.
        overlap_token_size: Number of tokens shared by consecutive windows.
        max_token_size: Maximum number of tokens per window.
        tiktoken_model: Model name forwarded to the tiktoken helpers.
        **kwargs: Accepted and ignored by this function, so callers may pass
            extra options without breaking it.

    Returns:
        A list of dicts with keys ``"tokens"`` (token count of the chunk),
        ``"content"`` (chunk text with surrounding whitespace stripped) and
        ``"chunk_order_index"`` (0-based position in the result).

    Raises:
        ValueError: If token windowing is required and ``max_token_size`` is
            not strictly greater than ``overlap_token_size``.
    """
    if split_by_character:
        sized_chunks = []
        for piece in content.split(split_by_character):
            piece_tokens = encode_string_by_tiktoken(piece, model_name=tiktoken_model)
            if split_by_character_only or len(piece_tokens) <= max_token_size:
                # The token count is measured on the unstripped piece; the
                # text itself is stripped when the result dict is built.
                sized_chunks.append((len(piece_tokens), piece))
            else:
                sized_chunks.extend(
                    _sliding_token_windows(
                        piece_tokens,
                        overlap_token_size,
                        max_token_size,
                        tiktoken_model,
                    )
                )
    else:
        # Only this strategy needs the tokens of the whole document, so the
        # full-text encoding is done here rather than unconditionally.
        tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
        sized_chunks = _sliding_token_windows(
            tokens, overlap_token_size, max_token_size, tiktoken_model
        )

    return [
        {
            "tokens": token_count,
            "content": text.strip(),
            "chunk_order_index": index,
        }
        for index, (token_count, text) in enumerate(sized_chunks)
    ]


def _sliding_token_windows(tokens, overlap_token_size, max_token_size, tiktoken_model):
    """Return ``(token_count, decoded_text)`` pairs for overlapping windows over *tokens*."""
    step = max_token_size - overlap_token_size
    if step <= 0:
        # range() rejects a zero step with an opaque message and yields
        # nothing for a negative one, which would silently drop the text.
        raise ValueError(
            "max_token_size must be greater than overlap_token_size "
            f"(got max_token_size={max_token_size}, "
            f"overlap_token_size={overlap_token_size})"
        )
    return [
        (
            min(max_token_size, len(tokens) - start),
            decode_tokens_by_tiktoken(
                tokens[start : start + max_token_size], model_name=tiktoken_model
            ),
        )
        for start in range(0, len(tokens), step)
    ]
|
||||
|
||||
|
||||
@@ -582,15 +619,22 @@ async def kg_query(
|
||||
logger.warning("low_level_keywords and high_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
|
||||
logger.warning("low_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
else:
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
logger.warning(
|
||||
"low_level_keywords is empty, switching from %s mode to global mode",
|
||||
query_param.mode,
|
||||
)
|
||||
query_param.mode = "global"
|
||||
if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
|
||||
logger.warning("high_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
else:
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
logger.warning(
|
||||
"high_level_keywords is empty, switching from %s mode to local mode",
|
||||
query_param.mode,
|
||||
)
|
||||
query_param.mode = "local"
|
||||
|
||||
ll_keywords = ", ".join(ll_keywords) if ll_keywords else ""
|
||||
hl_keywords = ", ".join(hl_keywords) if hl_keywords else ""
|
||||
|
||||
logger.info("Using %s mode for query processing", query_param.mode)
|
||||
|
||||
# Build context
|
||||
keywords = [ll_keywords, hl_keywords]
|
||||
@@ -656,78 +700,52 @@ async def _build_query_context(
|
||||
# ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
|
||||
# hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
|
||||
|
||||
ll_kewwords, hl_keywrds = query[0], query[1]
|
||||
if query_param.mode in ["local", "hybrid"]:
|
||||
if ll_kewwords == "":
|
||||
ll_entities_context, ll_relations_context, ll_text_units_context = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
warnings.warn(
|
||||
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||
)
|
||||
query_param.mode = "global"
|
||||
else:
|
||||
(
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
) = await _get_node_data(
|
||||
ll_kewwords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
if query_param.mode in ["global", "hybrid"]:
|
||||
if hl_keywrds == "":
|
||||
hl_entities_context, hl_relations_context, hl_text_units_context = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
warnings.warn(
|
||||
"High Level context is None. Return empty High entity/relationship/source"
|
||||
)
|
||||
query_param.mode = "local"
|
||||
else:
|
||||
(
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
) = await _get_edge_data(
|
||||
hl_keywrds,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
if (
|
||||
hl_entities_context == ""
|
||||
and hl_relations_context == ""
|
||||
and hl_text_units_context == ""
|
||||
):
|
||||
logger.warn("No high level context found. Switching to local mode.")
|
||||
query_param.mode = "local"
|
||||
if query_param.mode == "hybrid":
|
||||
ll_keywords, hl_keywords = query[0], query[1]
|
||||
|
||||
if query_param.mode == "local":
|
||||
entities_context, relations_context, text_units_context = await _get_node_data(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
elif query_param.mode == "global":
|
||||
entities_context, relations_context, text_units_context = await _get_edge_data(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
else: # hybrid mode
|
||||
(
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
) = await _get_node_data(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
(
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
) = await _get_edge_data(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
entities_context, relations_context, text_units_context = combine_contexts(
|
||||
[hl_entities_context, ll_entities_context],
|
||||
[hl_relations_context, ll_relations_context],
|
||||
[hl_text_units_context, ll_text_units_context],
|
||||
)
|
||||
elif query_param.mode == "local":
|
||||
entities_context, relations_context, text_units_context = (
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
)
|
||||
elif query_param.mode == "global":
|
||||
entities_context, relations_context, text_units_context = (
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
)
|
||||
return f"""
|
||||
-----Entities-----
|
||||
```csv
|
||||
|
Reference in New Issue
Block a user