Optimization logic

This commit is contained in:
jin
2024-11-25 13:29:55 +08:00
parent 662303f605
commit 89c2de54a2
10 changed files with 342 additions and 423 deletions

View File

@@ -248,14 +248,23 @@ async def extract_entities(
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
else:
examples="\n".join(PROMPTS["entity_extraction_examples"])
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
)
examples=examples,
language=language)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
@@ -270,7 +279,6 @@ async def extract_entities(
content = chunk_dp["content"]
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
@@ -388,8 +396,7 @@ async def extract_entities(
return knowledge_graph_inst
async def local_query(
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
@@ -399,43 +406,61 @@ async def local_query(
global_config: dict,
) -> str:
context = None
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
else:
examples="\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("local_query json_text:", json_text)
kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
result = await use_model_func(kw_prompt)
logger.info(f"kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError:
print(result)
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ", ".join(keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if keywords:
context = await _build_local_query_context(
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
# Handdle keywords missing
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
logger.warning("low_level_keywords is empty")
return PROMPTS["fail_response"]
else:
ll_keywords = ", ".join(ll_keywords)
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
logger.warning("high_level_keywords is empty")
return PROMPTS["fail_response"]
else:
hl_keywords = ", ".join(hl_keywords)
# Build context
keywords = [ll_keywords, hl_keywords]
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
@@ -443,13 +468,13 @@ async def local_query(
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -464,22 +489,87 @@ async def local_query(
return response
async def _build_local_query_context(
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
ll_kewwords, hl_keywrds = query[0], query[1]
if query_param.mode in ["local", "hybrid"]:
if ll_kewwords == "":
ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
query_param.mode = "global"
else:
ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
ll_kewwords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param
)
if query_param.mode in ["global", "hybrid"]:
if hl_keywrds == "":
hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
query_param.mode = "local"
else:
hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
hl_keywrds,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param
)
if query_param.mode == 'hybrid':
entities_context,relations_context,text_units_context = combine_contexts(
[hl_entities_context,ll_entities_context],
[hl_relations_context,ll_relations_context],
[hl_text_units_context,ll_text_units_context]
)
elif query_param.mode == 'local':
entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
elif query_param.mode == 'global':
entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
return f"""
# -----Entities-----
# ```csv
# {entities_context}
# ```
# -----Relationships-----
# ```csv
# {relations_context}
# ```
# -----Sources-----
# ```csv
# {text_units_context}
# ```
# """
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# 获取相似的实体
results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return None
# 获取实体信息
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
# 获取实体的度
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
@@ -488,15 +578,19 @@ async def _build_local_query_context(
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# 根据实体获取文本片段
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
# 获取关联的边
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
)
)
# 构建提示词
entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(node_datas):
entites_section_list.append(
@@ -531,20 +625,7 @@ async def _build_local_query_context(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
return entities_context,relations_context,text_units_context
async def _find_most_related_text_unit_from_entities(
@@ -659,88 +740,9 @@ async def _find_most_related_edges_from_entities(
return all_edges_data
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
context = None
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("global json_text:", json_text)
try:
keywords_data = json.loads(json_text)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if keywords:
context = await _build_global_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_global_query_context(
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
@@ -782,6 +784,7 @@ async def _build_global_query_context(
logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
)
relations_section_list = [
["id", "source", "target", "description", "keywords", "weight", "rank"]
]
@@ -816,21 +819,8 @@ async def _build_global_query_context(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def _find_most_related_entities_from_relationships(
@@ -901,137 +891,11 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
async def hybrid_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
low_level_context = None
high_level_context = None
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("hybrid_query json_text:", json_text)
try:
keywords_data = json.loads(json_text)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
except json.JSONDecodeError:
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if ll_keywords:
low_level_context = await _build_local_query_context(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
if hl_keywords:
high_level_context = await _build_global_query_context(
hl_keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
def combine_contexts(high_level_context, low_level_context):
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
def extract_sections(context):
entities_match = re.search(
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
relationships_match = re.search(
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
sources_match = re.search(
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
entities = entities_match.group(1) if entities_match else ""
relationships = relationships_match.group(1) if relationships_match else ""
sources = sources_match.group(1) if sources_match else ""
return entities, relationships, sources
# Extract sections from both contexts
if high_level_context is None:
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
hl_entities, hl_relationships, hl_sources = "", "", ""
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
if low_level_context is None:
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
ll_entities, ll_relationships, ll_sources = "", "", ""
else:
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
hl_entities, ll_entities = entities[0], entities[1]
hl_relationships, ll_relationships = relationships[0],relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
@@ -1043,21 +907,7 @@ def combine_contexts(high_level_context, low_level_context):
# Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources)
# Format the combined context
return f"""
-----Entities-----
```csv
{combined_entities}
```
-----Relationships-----
```csv
{combined_relationships}
```
-----Sources-----
```csv
{combined_sources}
```
"""
return combined_entities, combined_relationships, combined_sources
async def naive_query(
@@ -1080,7 +930,7 @@ async def naive_query(
max_token_size=query_param.max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]