Optimization logic
@@ -248,14 +248,23 @@ async def extract_entities(
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())

    # add language and example number params to prompt
    language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
        examples=examples,
        language=language)

    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]

@@ -270,7 +279,6 @@ async def extract_entities(
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = await use_llm_func(hint_prompt)

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)
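The loop above implements "gleaning": after the first extraction pass the model is re-prompted with the accumulated chat history, for up to entity_extract_max_gleaning extra rounds, so entities missed earlier can still be recovered (the if_loop_prompt set up earlier lets the model signal an early stop; that part is outside this hunk). A minimal sketch of the pattern, assuming llm behaves like use_llm_func and that the history keeps growing between rounds (the hunk ends before that line, so the accumulation step is an assumption):

    async def glean(llm, hint_prompt, continue_prompt, max_gleaning):
        # First extraction pass.
        final_result = await llm(hint_prompt)
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        for _ in range(max_gleaning):
            # Re-prompt with the running history so missed entities can be added.
            glean_result = await llm(continue_prompt, history_messages=history)
            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
            final_result += glean_result
        return final_result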

@@ -388,8 +396,7 @@ async def extract_entities(

    return knowledge_graph_inst


async def local_query(
async def kg_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
@@ -399,43 +406,61 @@ async def local_query(
    global_config: dict,
) -> str:
    context = None
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])

    # Set mode
    if query_param.mode not in ["local", "global", "hybrid"]:
        logger.error(f"Unknown mode {query_param.mode} in kg_query")
        return PROMPTS["fail_response"]
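    # kg_query now serves all three graph modes through one entry point; a
    # hypothetical call (argument order assumed from the signatures in this diff):
    #     await kg_query(question, knowledge_graph_inst, entities_vdb,
    #                    relationships_vdb, text_chunks_db,
    #                    QueryParam(mode="hybrid"), global_config)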

    # LLM generate keywords
    use_model_func = global_config["llm_model_func"]

    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
    logger.debug("local_query json_text:", json_text)
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
    result = await use_model_func(kw_prompt)
    logger.info(f"kw_prompt result:")
    print(result)
    try:
        json_text = locate_json_string_body_from_string(result)
        keywords_data = json.loads(json_text)
        keywords = keywords_data.get("low_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        print(result)
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"

            keywords_data = json.loads(result)
            keywords = keywords_data.get("low_level_keywords", [])
            keywords = ", ".join(keywords)
        # Handle parsing error
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    if keywords:
        context = await _build_local_query_context(
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])

    # Handle parsing error
    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e} {result}")
        return PROMPTS["fail_response"]

    # Handle missing keywords
    if hl_keywords == [] and ll_keywords == []:
        logger.warning("low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
        logger.warning("low_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        ll_keywords = ", ".join(ll_keywords)
    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
        logger.warning("high_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        hl_keywords = ", ".join(hl_keywords)

    # Build context
    keywords = [ll_keywords, hl_keywords]
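    # Order matters here: _build_query_context unpacks query[0] as the low-level
    # keywords and query[1] as the high-level keywords (see below).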
    context = await _build_query_context(
        keywords,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )

    if query_param.only_need_context:
        return context
    if context is None:
@@ -443,13 +468,13 @@ async def local_query(
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    )
    if query_param.only_need_prompt:
        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
@@ -464,22 +489,87 @@ async def local_query(
    return response


async def _build_local_query_context(
async def _build_query_context(
    query: list,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    ll_keywords, hl_keywords = query[0], query[1]
    if query_param.mode in ["local", "hybrid"]:
        if ll_keywords == "":
            ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
            warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
            query_param.mode = "global"
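            # If one side has no keywords, the mode collapses to the other side,
            # so the final dispatch below reads only the populated contexts.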
        else:
            ll_entities_context, ll_relations_context, ll_text_units_context = await _get_node_data(
                ll_keywords,
                knowledge_graph_inst,
                entities_vdb,
                text_chunks_db,
                query_param
            )
    if query_param.mode in ["global", "hybrid"]:
        if hl_keywords == "":
            hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
            warnings.warn("High Level context is None. Return empty High entity/relationship/source")
            query_param.mode = "local"
        else:
            hl_entities_context, hl_relations_context, hl_text_units_context = await _get_edge_data(
                hl_keywords,
                knowledge_graph_inst,
                relationships_vdb,
                text_chunks_db,
                query_param
            )
    if query_param.mode == 'hybrid':
        entities_context, relations_context, text_units_context = combine_contexts(
            [hl_entities_context, ll_entities_context],
            [hl_relations_context, ll_relations_context],
            [hl_text_units_context, ll_text_units_context]
        )
    elif query_param.mode == 'local':
        entities_context, relations_context, text_units_context = ll_entities_context, ll_relations_context, ll_text_units_context
    elif query_param.mode == 'global':
        entities_context, relations_context, text_units_context = hl_entities_context, hl_relations_context, hl_text_units_context
    return f"""
# -----Entities-----
# ```csv
# {entities_context}
# ```
# -----Relationships-----
# ```csv
# {relations_context}
# ```
# -----Sources-----
# ```csv
# {text_units_context}
# ```
# """


async def _get_node_data(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    # Retrieve similar entities
    results = await entities_vdb.query(query, top_k=query_param.top_k)

    if not len(results):
        return None
    # Get the entity data
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    # Get each entity's degree
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
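    # Both gather() calls above fan the per-entity graph lookups out
    # concurrently instead of awaiting them one at a time.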
@@ -488,15 +578,19 @@ async def _build_local_query_context(
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]  # TODO: what is this text_chunks_db doing? Don't remember it in airvx; check the diagram.
    # Get the text chunks associated with the entities
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    # Get the related edges
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )
    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )
    )

    # Build the prompt
    entites_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entites_section_list.append(
@@ -531,20 +625,7 @@ async def _build_local_query_context(
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
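    # Each section above is a small CSV table (header row first) that gets
    # embedded in the LLM prompt; list_of_list_to_csv serializes the rows
    # assembled in this function.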
    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
    return entities_context, relations_context, text_units_context


async def _find_most_related_text_unit_from_entities(
@@ -659,88 +740,9 @@ async def _find_most_related_edges_from_entities(
    return all_edges_data


async def global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    context = None
    use_model_func = global_config["llm_model_func"]

    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
    logger.debug("global json_text:", json_text)
    try:
        keywords_data = json.loads(json_text)
        keywords = keywords_data.get("high_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"

            keywords_data = json.loads(result)
            keywords = keywords_data.get("high_level_keywords", [])
            keywords = ", ".join(keywords)

        except json.JSONDecodeError as e:
            # Handle parsing error
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    if keywords:
        context = await _build_global_query_context(
            keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )

    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]

    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    if query_param.only_need_prompt:
        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    return response


async def _build_global_query_context(
async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
@@ -782,6 +784,7 @@ async def _build_global_query_context(
    logger.info(
        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
    )

    relations_section_list = [
        ["id", "source", "target", "description", "keywords", "weight", "rank"]
    ]
@@ -816,21 +819,8 @@ async def _build_global_query_context(
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context
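    # After the refactor the helpers return the raw (entities, relations, sources)
    # triple instead of a pre-rendered prompt string; _build_query_context and
    # combine_contexts assemble the final context from these pieces.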

    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""


async def _find_most_related_entities_from_relationships(
@@ -901,137 +891,11 @@ async def _find_related_text_unit_from_relationships(
    return all_text_units


async def hybrid_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    low_level_context = None
    high_level_context = None
    use_model_func = global_config["llm_model_func"]

    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)

    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
    logger.debug("hybrid_query json_text:", json_text)
    try:
        keywords_data = json.loads(json_text)
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])
        hl_keywords = ", ".join(hl_keywords)
        ll_keywords = ", ".join(ll_keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
            hl_keywords = ", ".join(hl_keywords)
            ll_keywords = ", ".join(ll_keywords)
        # Handle parsing error
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]

    if ll_keywords:
        low_level_context = await _build_local_query_context(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )

    if hl_keywords:
        high_level_context = await _build_global_query_context(
            hl_keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )

    context = combine_contexts(high_level_context, low_level_context)

    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]

    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    if query_param.only_need_prompt:
        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response


def combine_contexts(high_level_context, low_level_context):
def combine_contexts(entities, relationships, sources):
    # Function to extract entities, relationships, and sources from context strings

    def extract_sections(context):
        entities_match = re.search(
            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        relationships_match = re.search(
            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        sources_match = re.search(
            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )

        entities = entities_match.group(1) if entities_match else ""
        relationships = relationships_match.group(1) if relationships_match else ""
        sources = sources_match.group(1) if sources_match else ""

        return entities, relationships, sources

    # Extract sections from both contexts

    if high_level_context is None:
        warnings.warn(
            "High Level context is None. Return empty High entity/relationship/source"
        )
        hl_entities, hl_relationships, hl_sources = "", "", ""
    else:
        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)

    if low_level_context is None:
        warnings.warn(
            "Low Level context is None. Return empty Low entity/relationship/source"
        )
        ll_entities, ll_relationships, ll_sources = "", "", ""
    else:
        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)

    hl_entities, ll_entities = entities[0], entities[1]
    hl_relationships, ll_relationships = relationships[0], relationships[1]
    hl_sources, ll_sources = sources[0], sources[1]
    # Combine and deduplicate the entities
    combined_entities = process_combine_contexts(hl_entities, ll_entities)

@@ -1043,21 +907,7 @@ def combine_contexts(high_level_context, low_level_context):
    # Combine and deduplicate the sources
    combined_sources = process_combine_contexts(hl_sources, ll_sources)

    # Format the combined context
    return f"""
-----Entities-----
```csv
{combined_entities}
```
-----Relationships-----
```csv
{combined_relationships}
```
-----Sources-----
```csv
{combined_sources}
```
"""
    return combined_entities, combined_relationships, combined_sources
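combine_contexts now receives the three high/low pairs directly instead of re-parsing rendered context strings. A sketch of the call site, mirroring the dispatch in _build_query_context above:

    entities_context, relations_context, text_units_context = combine_contexts(
        [hl_entities_context, ll_entities_context],
        [hl_relations_context, ll_relations_context],
        [hl_text_units_context, ll_text_units_context],
    )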


async def naive_query(
@@ -1080,7 +930,7 @@ async def naive_query(
        max_token_size=query_param.max_token_for_text_unit,
    )
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    if query_param.only_need_context:
        return section
    sys_prompt_temp = PROMPTS["naive_rag_response"]