Logic Optimization
This commit is contained in:
@@ -249,13 +249,17 @@ async def extract_entities(
|
||||
|
||||
ordered_chunks = list(chunks.items())
|
||||
# add language and example number params to prompt
|
||||
language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
|
||||
language = global_config["addon_params"].get(
|
||||
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
||||
)
|
||||
example_number = global_config["addon_params"].get("example_number", None)
|
||||
if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
|
||||
examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
|
||||
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
|
||||
examples = "\n".join(
|
||||
PROMPTS["entity_extraction_examples"][: int(example_number)]
|
||||
)
|
||||
else:
|
||||
examples="\n".join(PROMPTS["entity_extraction_examples"])
|
||||
|
||||
examples = "\n".join(PROMPTS["entity_extraction_examples"])
|
||||
|
||||
entity_extract_prompt = PROMPTS["entity_extraction"]
|
||||
context_base = dict(
|
||||
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
||||
@@ -263,8 +267,9 @@ async def extract_entities(
|
||||
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
||||
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
|
||||
examples=examples,
|
||||
language=language)
|
||||
|
||||
language=language,
|
||||
)
|
||||
|
||||
continue_prompt = PROMPTS["entiti_continue_extraction"]
|
||||
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
|
||||
|
||||
@@ -396,6 +401,7 @@ async def extract_entities(
|
||||
|
||||
return knowledge_graph_inst
|
||||
|
||||
|
||||
async def kg_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -408,59 +414,61 @@ async def kg_query(
|
||||
context = None
|
||||
example_number = global_config["addon_params"].get("example_number", None)
|
||||
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
|
||||
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
|
||||
examples = "\n".join(
|
||||
PROMPTS["keywords_extraction_examples"][: int(example_number)]
|
||||
)
|
||||
else:
|
||||
examples="\n".join(PROMPTS["keywords_extraction_examples"])
|
||||
|
||||
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
|
||||
|
||||
# Set mode
|
||||
if query_param.mode not in ["local", "global", "hybrid"]:
|
||||
logger.error(f"Unknown mode {query_param.mode} in kg_query")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
|
||||
# LLM generate keywords
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
|
||||
result = await use_model_func(kw_prompt)
|
||||
logger.info(f"kw_prompt result:")
|
||||
kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
|
||||
result = await use_model_func(kw_prompt)
|
||||
logger.info("kw_prompt result:")
|
||||
print(result)
|
||||
try:
|
||||
json_text = locate_json_string_body_from_string(result)
|
||||
keywords_data = json.loads(json_text)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
|
||||
|
||||
# Handle parsing error
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e} {result}")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
|
||||
# Handle missing keywords
|
||||
if hl_keywords == [] and ll_keywords == []:
|
||||
logger.warning("low_level_keywords and high_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
|
||||
return PROMPTS["fail_response"]
|
||||
if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
|
||||
logger.warning("low_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
else:
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
|
||||
if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
|
||||
logger.warning("high_level_keywords is empty")
|
||||
return PROMPTS["fail_response"]
|
||||
else:
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
|
||||
|
||||
# Build context
|
||||
keywords = [ll_keywords, hl_keywords]
|
||||
keywords = [ll_keywords, hl_keywords]
|
||||
context = await _build_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
if query_param.only_need_context:
|
||||
return context
|
||||
if context is None:
|
||||
@@ -468,13 +476,13 @@ async def kg_query(
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
)
|
||||
)
|
||||
if query_param.only_need_prompt:
|
||||
return sys_prompt
|
||||
response = await use_model_func(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
)
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
@@ -496,44 +504,72 @@ async def _build_query_context(
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
):
|
||||
):
|
||||
ll_kewwords, hl_keywrds = query[0], query[1]
|
||||
if query_param.mode in ["local", "hybrid"]:
|
||||
if ll_kewwords == "":
|
||||
ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
|
||||
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
||||
ll_entities_context, ll_relations_context, ll_text_units_context = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
warnings.warn(
|
||||
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||
)
|
||||
query_param.mode = "global"
|
||||
else:
|
||||
ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
|
||||
(
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
) = await _get_node_data(
|
||||
ll_kewwords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param
|
||||
)
|
||||
query_param,
|
||||
)
|
||||
if query_param.mode in ["global", "hybrid"]:
|
||||
if hl_keywrds == "":
|
||||
hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
|
||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
||||
hl_entities_context, hl_relations_context, hl_text_units_context = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
warnings.warn(
|
||||
"High Level context is None. Return empty High entity/relationship/source"
|
||||
)
|
||||
query_param.mode = "local"
|
||||
else:
|
||||
hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
|
||||
(
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
) = await _get_edge_data(
|
||||
hl_keywrds,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param
|
||||
)
|
||||
if query_param.mode == 'hybrid':
|
||||
entities_context,relations_context,text_units_context = combine_contexts(
|
||||
[hl_entities_context,ll_entities_context],
|
||||
[hl_relations_context,ll_relations_context],
|
||||
[hl_text_units_context,ll_text_units_context]
|
||||
)
|
||||
elif query_param.mode == 'local':
|
||||
entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
|
||||
elif query_param.mode == 'global':
|
||||
entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
|
||||
query_param,
|
||||
)
|
||||
if query_param.mode == "hybrid":
|
||||
entities_context, relations_context, text_units_context = combine_contexts(
|
||||
[hl_entities_context, ll_entities_context],
|
||||
[hl_relations_context, ll_relations_context],
|
||||
[hl_text_units_context, ll_text_units_context],
|
||||
)
|
||||
elif query_param.mode == "local":
|
||||
entities_context, relations_context, text_units_context = (
|
||||
ll_entities_context,
|
||||
ll_relations_context,
|
||||
ll_text_units_context,
|
||||
)
|
||||
elif query_param.mode == "global":
|
||||
entities_context, relations_context, text_units_context = (
|
||||
hl_entities_context,
|
||||
hl_relations_context,
|
||||
hl_text_units_context,
|
||||
)
|
||||
return f"""
|
||||
# -----Entities-----
|
||||
# ```csv
|
||||
@@ -550,7 +586,6 @@ async def _build_query_context(
|
||||
# """
|
||||
|
||||
|
||||
|
||||
async def _get_node_data(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -568,7 +603,7 @@ async def _get_node_data(
|
||||
)
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
|
||||
|
||||
# Get the degree of each entity
|
||||
node_degrees = await asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
@@ -588,7 +623,7 @@ async def _get_node_data(
|
||||
)
|
||||
logger.info(
|
||||
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
|
||||
)
|
||||
)
|
||||
|
||||
# Build the prompt
|
||||
entites_section_list = [["id", "entity", "type", "description", "rank"]]
|
||||
@@ -625,7 +660,7 @@ async def _get_node_data(
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return entities_context,relations_context,text_units_context
|
||||
return entities_context, relations_context, text_units_context
|
||||
|
||||
|
||||
async def _find_most_related_text_unit_from_entities(
|
||||
@@ -821,8 +856,7 @@ async def _get_edge_data(
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return entities_context,relations_context,text_units_context
|
||||
|
||||
return entities_context, relations_context, text_units_context
|
||||
|
||||
|
||||
async def _find_most_related_entities_from_relationships(
|
||||
@@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships(
|
||||
def combine_contexts(entities, relationships, sources):
|
||||
# Function to extract entities, relationships, and sources from context strings
|
||||
hl_entities, ll_entities = entities[0], entities[1]
|
||||
hl_relationships, ll_relationships = relationships[0],relationships[1]
|
||||
hl_relationships, ll_relationships = relationships[0], relationships[1]
|
||||
hl_sources, ll_sources = sources[0], sources[1]
|
||||
# Combine and deduplicate the entities
|
||||
combined_entities = process_combine_contexts(hl_entities, ll_entities)
|
||||
|
Reference in New Issue
Block a user