chunk split retry

童石渊
2025-01-07 16:26:12 +08:00
parent 059e3882f1
commit 6b19401dc6
3 changed files with 886 additions and 135 deletions

lightrag/operate.py

@@ -34,7 +34,11 @@ import time
 def chunking_by_token_size(
-    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str,
+    split_by_character=None,
+    overlap_token_size=128,
+    max_token_size=1024,
+    tiktoken_model="gpt-4o",
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
@@ -44,11 +48,16 @@ def chunking_by_token_size(
         for chunk in raw_chunks:
             _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
             if len(_tokens) > max_token_size:
-                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                for start in range(
+                    0, len(_tokens), max_token_size - overlap_token_size
+                ):
                     chunk_content = decode_tokens_by_tiktoken(
-                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                        _tokens[start : start + max_token_size],
+                        model_name=tiktoken_model,
                     )
-                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+                    new_chunks.append(
+                        (min(max_token_size, len(_tokens) - start), chunk_content)
+                    )
             else:
                 new_chunks.append((len(_tokens), chunk))
         for index, (_len, chunk) in enumerate(new_chunks):
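Note: the wrapped loop steps through the token array in strides of max_token_size - overlap_token_size, so consecutive windows share overlap_token_size tokens. A minimal standalone sketch of the same windowing (function name illustrative, plain ints instead of tiktoken tokens):

def window_spans(n_tokens: int, max_size: int = 1024, overlap: int = 128):
    # Yield (start, end) pairs covering n_tokens with overlapping windows
    step = max_size - overlap
    for start in range(0, n_tokens, step):
        yield start, min(start + max_size, n_tokens)

print(list(window_spans(2500)))  # [(0, 1024), (896, 1920), (1792, 2500)]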
@@ -61,10 +70,10 @@ def chunking_by_token_size(
             )
     else:
         for index, start in enumerate(
-                range(0, len(tokens), max_token_size - overlap_token_size)
+            range(0, len(tokens), max_token_size - overlap_token_size)
         ):
             chunk_content = decode_tokens_by_tiktoken(
-                tokens[start: start + max_token_size], model_name=tiktoken_model
+                tokens[start : start + max_token_size], model_name=tiktoken_model
             )
             results.append(
                 {
@@ -77,9 +86,9 @@ def chunking_by_token_size(


 async def _handle_entity_relation_summary(
-        entity_or_relation_name: str,
-        description: str,
-        global_config: dict,
+    entity_or_relation_name: str,
+    description: str,
+    global_config: dict,
 ) -> str:
     use_llm_func: callable = global_config["llm_model_func"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
@@ -108,8 +117,8 @@ async def _handle_entity_relation_summary(


 async def _handle_single_entity_extraction(
-        record_attributes: list[str],
-        chunk_key: str,
+    record_attributes: list[str],
+    chunk_key: str,
 ):
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
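The guard above requires at least four fields because a well-formed record carries the '"entity"' marker plus name, type, and description. The same check in isolation (helper name illustrative):

def looks_like_entity_record(record_attributes: list[str]) -> bool:
    # marker, name, type, description => at least 4 fields
    return len(record_attributes) >= 4 and record_attributes[0] == '"entity"'

print(looks_like_entity_record(['"entity"', "Alan Turing", "person", "mathematician"]))  # True
print(looks_like_entity_record(['"relationship"', "a", "b"]))  # False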
@@ -129,8 +138,8 @@ async def _handle_single_entity_extraction(


 async def _handle_single_relationship_extraction(
-        record_attributes: list[str],
-        chunk_key: str,
+    record_attributes: list[str],
+    chunk_key: str,
 ):
     if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
         return None
@@ -156,10 +165,10 @@ async def _handle_single_relationship_extraction(


 async def _merge_nodes_then_upsert(
-        entity_name: str,
-        nodes_data: list[dict],
-        knowledge_graph_inst: BaseGraphStorage,
-        global_config: dict,
+    entity_name: str,
+    nodes_data: list[dict],
+    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict,
 ):
     already_entity_types = []
     already_source_ids = []
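The `already_*` accumulators gather fields from any node already in the graph so they can be merged with the new extraction records. A sketch of the dedupe-and-join pattern this implies (merge_source_ids is illustrative; GRAPH_FIELD_SEP stands in for the repo's separator constant):

GRAPH_FIELD_SEP = "<SEP>"  # assumed separator between merged source ids

def merge_source_ids(existing: str | None, new_ids: list[str]) -> str:
    merged = (existing.split(GRAPH_FIELD_SEP) if existing else []) + new_ids
    return GRAPH_FIELD_SEP.join(dict.fromkeys(merged))  # dedupe, keep order

print(merge_source_ids("chunk-1<SEP>chunk-2", ["chunk-2", "chunk-3"]))
# chunk-1<SEP>chunk-2<SEP>chunk-3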
@@ -203,11 +212,11 @@ async def _merge_nodes_then_upsert(


 async def _merge_edges_then_upsert(
-        src_id: str,
-        tgt_id: str,
-        edges_data: list[dict],
-        knowledge_graph_inst: BaseGraphStorage,
-        global_config: dict,
+    src_id: str,
+    tgt_id: str,
+    edges_data: list[dict],
+    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict,
 ):
     already_weights = []
     already_source_ids = []
@@ -270,12 +279,12 @@ async def _merge_edges_then_upsert(


 async def extract_entities(
-        chunks: dict[str, TextChunkSchema],
-        knowledge_graph_inst: BaseGraphStorage,
-        entity_vdb: BaseVectorStorage,
-        relationships_vdb: BaseVectorStorage,
-        global_config: dict,
-        llm_response_cache: BaseKVStorage = None,
+    chunks: dict[str, TextChunkSchema],
+    knowledge_graph_inst: BaseGraphStorage,
+    entity_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    global_config: dict,
+    llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
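`entity_extract_max_gleaning` bounds how many follow-up "gleaning" passes the extractor may run over one chunk to pick up missed entities. A rough sketch of that control flow, with a hypothetical `llm` callable standing in for `use_llm_func`:

async def extract_with_gleaning(llm, chunk_text: str, max_gleaning: int) -> str:
    result = await llm(f"Extract entities from:\n{chunk_text}")
    for _ in range(max_gleaning):
        # Ask the model to add anything it missed; append to the running result
        result += await llm("Some entities were missed. Add them.")
        # A fuller version would also ask whether another pass is needed and stop early
    return result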
@@ -327,13 +336,13 @@ async def extract_entities(
     already_relations = 0

     async def _user_llm_func_with_cache(
-            input_text: str, history_messages: list[dict[str, str]] = None
+        input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
             need_to_restore = False
             if (
-                    global_config["embedding_cache_config"]
-                    and global_config["embedding_cache_config"]["enabled"]
+                global_config["embedding_cache_config"]
+                and global_config["embedding_cache_config"]["enabled"]
             ):
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
@@ -435,7 +444,7 @@ async def extract_entities(
         already_relations += len(maybe_edges)
         now_ticks = PROMPTS["process_tickers"][
             already_processed % len(PROMPTS["process_tickers"])
-            ]
+        ]
         print(
             f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
             end="",
@@ -445,10 +454,10 @@ async def extract_entities(

     results = []
     for result in tqdm_async(
-            asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
-            total=len(ordered_chunks),
-            desc="Extracting entities from chunks",
-            unit="chunk",
+        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
+        total=len(ordered_chunks),
+        desc="Extracting entities from chunks",
+        unit="chunk",
     ):
         results.append(await result)
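This pattern schedules one coroutine per chunk and consumes results in completion order, with tqdm's asyncio wrapper drawing the progress bar. A self-contained version of the same idiom, using a toy work() coroutine in place of the real chunk processor:

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def work(i: int) -> int:
    await asyncio.sleep(0.01 * (i % 3))  # uneven latency, like real LLM calls
    return i * i

async def main():
    results = []
    for fut in tqdm_async(
        asyncio.as_completed([work(i) for i in range(10)]),
        total=10,
        desc="Extracting entities from chunks",
        unit="chunk",
    ):
        results.append(await fut)  # arrives in completion order, not submission order
    print(sorted(results))

asyncio.run(main())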
@@ -462,32 +471,32 @@ async def extract_entities(
     logger.info("Inserting entities into storage...")
     all_entities_data = []
     for result in tqdm_async(
-            asyncio.as_completed(
-                [
-                    _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-                    for k, v in maybe_nodes.items()
-                ]
-            ),
-            total=len(maybe_nodes),
-            desc="Inserting entities",
-            unit="entity",
+        asyncio.as_completed(
+            [
+                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+                for k, v in maybe_nodes.items()
+            ]
+        ),
+        total=len(maybe_nodes),
+        desc="Inserting entities",
+        unit="entity",
     ):
         all_entities_data.append(await result)

     logger.info("Inserting relationships into storage...")
     all_relationships_data = []
     for result in tqdm_async(
-            asyncio.as_completed(
-                [
-                    _merge_edges_then_upsert(
-                        k[0], k[1], v, knowledge_graph_inst, global_config
-                    )
-                    for k, v in maybe_edges.items()
-                ]
-            ),
-            total=len(maybe_edges),
-            desc="Inserting relationships",
-            unit="relationship",
+        asyncio.as_completed(
+            [
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
+                for k, v in maybe_edges.items()
+            ]
+        ),
+        total=len(maybe_edges),
+        desc="Inserting relationships",
+        unit="relationship",
     ):
         all_relationships_data.append(await result)
@@ -518,9 +527,9 @@ async def extract_entities(
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
                 "content": dp["keywords"]
-                           + dp["src_id"]
-                           + dp["tgt_id"]
-                           + dp["description"],
+                + dp["src_id"]
+                + dp["tgt_id"]
+                + dp["description"],
                 "metadata": {
                     "created_at": dp.get("metadata", {}).get("created_at", time.time())
                 },
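The `content` field concatenates keywords, both endpoint ids, and the description into a single string for embedding. One record of the upsert payload might look like this (the md5-based id scheme is an assumption, not the repo's exact helper):

from hashlib import md5

def rel_vdb_record(dp: dict) -> dict:
    content = dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"]
    return {
        "id": "rel-" + md5((dp["src_id"] + dp["tgt_id"]).encode()).hexdigest(),
        "src_id": dp["src_id"],
        "tgt_id": dp["tgt_id"],
        "content": content,
    }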
@@ -533,14 +542,14 @@ async def extract_entities(


 async def kg_query(
-        query,
-        knowledge_graph_inst: BaseGraphStorage,
-        entities_vdb: BaseVectorStorage,
-        relationships_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
-        global_config: dict,
-        hashing_kv: BaseKVStorage = None,
+    query,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
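`hashing_kv` backs the query-level cache behind the `# Handle cache` comment: hash the (mode, query) pair, return a stored response on a hit, and store the fresh response otherwise. A hedged sketch of that shape (function name, key scheme, and field names illustrative):

import hashlib

async def cached_query(hashing_kv, mode: str, query: str, compute):
    args_hash = hashlib.md5(f"{mode}:{query}".encode()).hexdigest()
    if hashing_kv is not None:
        hit = await hashing_kv.get_by_id(args_hash)
        if hit is not None:
            return hit["return"]  # cached answer, no LLM round-trip
    response = await compute(query)
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "mode": mode}})
    return response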
@@ -660,12 +669,12 @@ async def kg_query(


 async def _build_query_context(
-        query: list,
-        knowledge_graph_inst: BaseGraphStorage,
-        entities_vdb: BaseVectorStorage,
-        relationships_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
+    query: list,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
 ):
     # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
     # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
@@ -718,9 +727,9 @@ async def _build_query_context(
             query_param,
         )
         if (
-                hl_entities_context == ""
-                and hl_relations_context == ""
-                and hl_text_units_context == ""
+            hl_entities_context == ""
+            and hl_relations_context == ""
+            and hl_text_units_context == ""
         ):
             logger.warn("No high level context found. Switching to local mode.")
             query_param.mode = "local"
@@ -759,11 +768,11 @@ async def _build_query_context(


 async def _get_node_data(
-        query,
-        knowledge_graph_inst: BaseGraphStorage,
-        entities_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
+    query,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
 ):
     # get similar entities
     results = await entities_vdb.query(query, top_k=query_param.top_k)
@@ -850,10 +859,10 @@ async def _get_node_data(


 async def _find_most_related_text_unit_from_entities(
-        node_datas: list[dict],
-        query_param: QueryParam,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        knowledge_graph_inst: BaseGraphStorage,
+    node_datas: list[dict],
+    query_param: QueryParam,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -893,8 +902,8 @@ async def _find_most_related_text_unit_from_entities(
         if this_edges:
             for e in this_edges:
                 if (
-                        e[1] in all_one_hop_text_units_lookup
-                        and c_id in all_one_hop_text_units_lookup[e[1]]
+                    e[1] in all_one_hop_text_units_lookup
+                    and c_id in all_one_hop_text_units_lookup[e[1]]
                 ):
                     all_text_units_lookup[c_id]["relation_counts"] += 1
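`relation_counts` scores a candidate chunk by how many one-hop edges point at entities whose own text units also cite that chunk; the reformatted condition above is that membership test. The same test in isolation (helper name illustrative):

def relation_count(c_id: str, this_edges: list[tuple[str, str]],
                   one_hop_lookup: dict[str, set[str]]) -> int:
    # Count one-hop neighbours whose text units also contain chunk c_id
    return sum(
        1
        for e in this_edges
        if e[1] in one_hop_lookup and c_id in one_hop_lookup[e[1]]
    )

edges = [("A", "B"), ("A", "C")]
lookup = {"B": {"chunk-7"}, "C": {"chunk-9"}}
print(relation_count("chunk-7", edges, lookup))  # 1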
@@ -924,9 +933,9 @@ async def _find_most_related_text_unit_from_entities(


 async def _find_most_related_edges_from_entities(
-        node_datas: list[dict],
-        query_param: QueryParam,
-        knowledge_graph_inst: BaseGraphStorage,
+    node_datas: list[dict],
+    query_param: QueryParam,
+    knowledge_graph_inst: BaseGraphStorage,
 ):
     all_related_edges = await asyncio.gather(
         *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
@@ -964,11 +973,11 @@ async def _find_most_related_edges_from_entities(


 async def _get_edge_data(
-        keywords,
-        knowledge_graph_inst: BaseGraphStorage,
-        relationships_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
+    keywords,
+    knowledge_graph_inst: BaseGraphStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
 ):
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@@ -1066,9 +1075,9 @@ async def _get_edge_data(


 async def _find_most_related_entities_from_relationships(
-        edge_datas: list[dict],
-        query_param: QueryParam,
-        knowledge_graph_inst: BaseGraphStorage,
+    edge_datas: list[dict],
+    query_param: QueryParam,
+    knowledge_graph_inst: BaseGraphStorage,
 ):
     entity_names = []
     seen = set()
@@ -1103,10 +1112,10 @@ async def _find_most_related_entities_from_relationships(


 async def _find_related_text_unit_from_relationships(
-        edge_datas: list[dict],
-        query_param: QueryParam,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        knowledge_graph_inst: BaseGraphStorage,
+    edge_datas: list[dict],
+    query_param: QueryParam,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1172,12 +1181,12 @@ def combine_contexts(entities, relationships, sources):


 async def naive_query(
-        query,
-        chunks_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
-        global_config: dict,
-        hashing_kv: BaseKVStorage = None,
+    query,
+    chunks_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ):
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1235,7 +1244,7 @@ async def naive_query(
     if len(response) > len(sys_prompt):
         response = (
-            response[len(sys_prompt):]
+            response[len(sys_prompt) :]
             .replace(sys_prompt, "")
             .replace("user", "")
            .replace("model", "")
@@ -1263,15 +1272,15 @@ async def mix_kg_vector_query(


 async def mix_kg_vector_query(
-        query,
-        knowledge_graph_inst: BaseGraphStorage,
-        entities_vdb: BaseVectorStorage,
-        relationships_vdb: BaseVectorStorage,
-        chunks_vdb: BaseVectorStorage,
-        text_chunks_db: BaseKVStorage[TextChunkSchema],
-        query_param: QueryParam,
-        global_config: dict,
-        hashing_kv: BaseKVStorage = None,
+    query,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    chunks_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ) -> str:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1296,7 +1305,7 @@ async def mix_kg_vector_query(
     # Reuse keyword extraction logic from kg_query
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(
-            PROMPTS["keywords_extraction_examples"]
+        PROMPTS["keywords_extraction_examples"]
     ):
         examples = "\n".join(
             PROMPTS["keywords_extraction_examples"][: int(example_number)]