Logic Optimization

This commit is contained in:
jin
2024-11-25 13:40:38 +08:00
parent bf5815be8f
commit 21f161390a
8 changed files with 185 additions and 136 deletions

View File

@@ -249,13 +249,17 @@ async def extract_entities(
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
examples = "\n".join(
PROMPTS["entity_extraction_examples"][: int(example_number)]
)
else:
examples="\n".join(PROMPTS["entity_extraction_examples"])
examples = "\n".join(PROMPTS["entity_extraction_examples"])
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
@@ -263,8 +267,9 @@ async def extract_entities(
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
examples=examples,
language=language)
language=language,
)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
@@ -396,6 +401,7 @@ async def extract_entities(
return knowledge_graph_inst
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -408,59 +414,61 @@ async def kg_query(
context = None
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples="\n".join(PROMPTS["keywords_extraction_examples"])
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
result = await use_model_func(kw_prompt)
logger.info(f"kw_prompt result:")
kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
result = await use_model_func(kw_prompt)
logger.info("kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
# Handdle keywords missing
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty")
return PROMPTS["fail_response"]
else:
ll_keywords = ", ".join(ll_keywords)
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty")
return PROMPTS["fail_response"]
else:
hl_keywords = ", ".join(hl_keywords)
# Build context
keywords = [ll_keywords, hl_keywords]
keywords = [ll_keywords, hl_keywords]
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
@@ -468,13 +476,13 @@ async def kg_query(
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -496,44 +504,72 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
):
ll_kewwords, hl_keywrds = query[0], query[1]
if query_param.mode in ["local", "hybrid"]:
if ll_kewwords == "":
ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
ll_entities_context, ll_relations_context, ll_text_units_context = (
"",
"",
"",
)
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
query_param.mode = "global"
else:
ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = await _get_node_data(
ll_kewwords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param
)
query_param,
)
if query_param.mode in ["global", "hybrid"]:
if hl_keywrds == "":
hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities_context, hl_relations_context, hl_text_units_context = (
"",
"",
"",
)
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
query_param.mode = "local"
else:
hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = await _get_edge_data(
hl_keywrds,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param
)
if query_param.mode == 'hybrid':
entities_context,relations_context,text_units_context = combine_contexts(
[hl_entities_context,ll_entities_context],
[hl_relations_context,ll_relations_context],
[hl_text_units_context,ll_text_units_context]
)
elif query_param.mode == 'local':
entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
elif query_param.mode == 'global':
entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
query_param,
)
if query_param.mode == "hybrid":
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context],
[hl_text_units_context, ll_text_units_context],
)
elif query_param.mode == "local":
entities_context, relations_context, text_units_context = (
ll_entities_context,
ll_relations_context,
ll_text_units_context,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = (
hl_entities_context,
hl_relations_context,
hl_text_units_context,
)
return f"""
# -----Entities-----
# ```csv
@@ -550,7 +586,6 @@ async def _build_query_context(
# """
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -568,7 +603,7 @@ async def _get_node_data(
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
# 获取实体的度
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
@@ -588,7 +623,7 @@ async def _get_node_data(
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
)
)
# 构建提示词
entites_section_list = [["id", "entity", "type", "description", "rank"]]
@@ -625,7 +660,7 @@ async def _get_node_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities(
@@ -821,8 +856,7 @@ async def _get_edge_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships(
@@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships(
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
hl_entities, ll_entities = entities[0], entities[1]
hl_relationships, ll_relationships = relationships[0],relationships[1]
hl_relationships, ll_relationships = relationships[0], relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)