This commit is contained in:
Larfii
2024-10-08 10:38:50 +08:00
parent 286e5319b9
commit 44463503fd
11 changed files with 90 additions and 18 deletions

View File

@@ -176,7 +176,6 @@ async def _merge_edges_then_upsert(
already_weights = []
already_source_ids = []
already_description = []
##################
already_keywords = []
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
@@ -186,7 +185,6 @@ async def _merge_edges_then_upsert(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_edge["description"])
############
already_keywords.extend(
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
)
@@ -195,7 +193,6 @@ async def _merge_edges_then_upsert(
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
)
##########
keywords = GRAPH_FIELD_SEP.join(
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
)
@@ -403,7 +400,7 @@ async def local_query(
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
context = await _build_local_query_context(
keywords,
@@ -415,7 +412,7 @@ async def local_query(
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -424,7 +421,7 @@ async def local_query(
query,
system_prompt=sys_prompt,
)
return response, context
return response
async def _build_local_query_context(
query,
@@ -622,7 +619,7 @@ async def global_query(
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
context = await _build_global_query_context(
keywords,
@@ -636,7 +633,7 @@ async def global_query(
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
@@ -646,7 +643,7 @@ async def global_query(
query,
system_prompt=sys_prompt,
)
return (response, context)
return response
async def _build_global_query_context(
keywords,
@@ -836,7 +833,7 @@ async def hybird_query(
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
low_level_context = await _build_local_query_context(
ll_keywords,
@@ -860,7 +857,7 @@ async def hybird_query(
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
@@ -870,7 +867,7 @@ async def hybird_query(
query,
system_prompt=sys_prompt,
)
return (response, context)
return response
def combine_contexts(high_level_context, low_level_context):
# Function to extract entities, relationships, and sources from context strings
@@ -922,14 +919,14 @@ async def naive_query(
use_model_func = global_config["llm_model_func"]
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"], "None"
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
max_token_size=query_param.naive_max_token_for_text_unit,
max_token_size=query_param.max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])