Logic Optimization

jin
2024-11-25 13:40:38 +08:00
parent bf5815be8f
commit 21f161390a
8 changed files with 185 additions and 136 deletions


@@ -114,7 +114,9 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(
+        self, sql: str, params: dict = None, multirows: bool = False
+    ) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
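Note: the new `params` argument lets callers hand bind variables to the driver instead of interpolating values into the SQL string. A minimal sketch of a caller using the patched signature (the table and bind names are illustrative, not from this commit):

async def demo(db) -> None:
    # `db` is assumed to be a connected OracleDB instance as defined above;
    # LIGHTRAG_DOC_FULL, :id and :workspace are hypothetical names.
    row = await db.query(
        "SELECT content FROM LIGHTRAG_DOC_FULL WHERE id = :id",
        params={"id": "doc-1"},
    )
    rows = await db.query(
        "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace = :workspace",
        params={"workspace": "default"},
        multirows=True,
    )
    print(row, rows)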
@@ -256,7 +258,7 @@ class OracleKVStorage(BaseKVStorage):
item["__vector__"],
]
# print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)
if self.namespace == "full_docs":
for k, v in self._data.items():
@@ -266,7 +268,7 @@ class OracleKVStorage(BaseKVStorage):
)
values = [k, self._data[k]["content"], self.db.workspace]
# print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)
return left_data
async def index_done_callback(self):
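Note: both call sites above previously passed `data` rather than the `values` list built just before each call, so the binds did not match the row being merged. A reduced, hypothetical helper showing the corrected pattern (SQL and table name are illustrative):

async def upsert_full_doc(db, key: str, content: str) -> None:
    # The binds handed to execute() must be the `values` built for this
    # row, not a leftover variable from an earlier statement.
    merge_sql = (
        "MERGE INTO LIGHTRAG_DOC_FULL t "
        "USING (SELECT :1 AS id, :2 AS content, :3 AS workspace FROM dual) s "
        "ON (t.id = s.id) "
        "WHEN NOT MATCHED THEN INSERT (id, content, workspace) "
        "VALUES (s.id, s.content, s.workspace)"
    )
    values = [key, content, db.workspace]
    await db.execute(merge_sql, values)  # was: execute(merge_sql, data)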


@@ -70,8 +70,8 @@ async def openai_complete_if_cache(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
-        if r'\u' in content:
-            content = content.encode('utf-8').decode('unicode_escape')
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
print(content)
if hashing_kv is not None:
await hashing_kv.upsert(
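Note: only the quote style changes here; the branch itself decodes literal \uXXXX escape sequences that some models emit verbatim. A standalone illustration, with the caveat that the utf-8/unicode_escape round-trip can garble text that already contains real non-ASCII characters:

content = r"He said \u201chello\u201d"  # literal backslash-u sequences
if r"\u" in content:
    # unicode_escape interprets \uXXXX; it decodes the bytes as Latin-1,
    # so genuinely non-ASCII input would be mangled by this round-trip.
    content = content.encode("utf-8").decode("unicode_escape")
print(content)  # He said “hello”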
@@ -542,7 +542,7 @@ async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
-    api_key: str = None
+    api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@@ -551,7 +551,7 @@ async def openai_embedding(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
-            model=model, input=texts, encoding_format="float"
+        model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
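Note: these two hunks are formatting only (trailing comma, indentation). For reference, a hedged sketch of driving the patched function; the printed shape assumes the 1536-dimension output documented for text-embedding-3-small:

import asyncio

async def main() -> None:
    # Assumes OPENAI_API_KEY is set (or api_key=... is passed) and that
    # network access to the OpenAI API is available.
    vecs = await openai_embedding(["hello world", "graph rag"])
    print(vecs.shape)  # expected (2, 1536)

asyncio.run(main())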


@@ -249,13 +249,17 @@ async def extract_entities(
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
-    language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )
example_number = global_config["addon_params"].get("example_number", None)
-    if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
-        examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
+    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
+        examples = "\n".join(
+            PROMPTS["entity_extraction_examples"][: int(example_number)]
+        )
else:
examples="\n".join(PROMPTS["entity_extraction_examples"])
examples = "\n".join(PROMPTS["entity_extraction_examples"])
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
@@ -263,8 +267,9 @@ async def extract_entities(
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
examples=examples,
-        language=language)
+        language=language,
+    )
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
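Note: the reflowed block reads `language` and `example_number` from `addon_params` and truncates the few-shot examples when a cap is given. A self-contained sketch of that selection logic, with toy stand-ins for PROMPTS and global_config:

PROMPTS = {
    "DEFAULT_LANGUAGE": "English",
    "entity_extraction_examples": ["ex1", "ex2", "ex3"],
}
global_config = {"addon_params": {"language": "Chinese", "example_number": 2}}

language = global_config["addon_params"].get(
    "language", PROMPTS["DEFAULT_LANGUAGE"]
)
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
    examples = "\n".join(
        PROMPTS["entity_extraction_examples"][: int(example_number)]
    )
else:
    examples = "\n".join(PROMPTS["entity_extraction_examples"])
assert language == "Chinese" and examples == "ex1\nex2"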
@@ -396,6 +401,7 @@ async def extract_entities(
return knowledge_graph_inst
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -408,59 +414,61 @@ async def kg_query(
context = None
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples="\n".join(PROMPTS["keywords_extraction_examples"])
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
-    result = await use_model_func(kw_prompt)
-    logger.info(f"kw_prompt result:")
+    kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
+    result = await use_model_func(kw_prompt)
+    logger.info("kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
    # Handle missing keywords
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty")
return PROMPTS["fail_response"]
else:
ll_keywords = ", ".join(ll_keywords)
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty")
return PROMPTS["fail_response"]
else:
hl_keywords = ", ".join(hl_keywords)
# Build context
-    keywords = [ll_keywords, hl_keywords]
+    keywords = [ll_keywords, hl_keywords]
context = await _build_query_context(
-            keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
+        keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        text_chunks_db,
+        query_param,
+    )
if query_param.only_need_context:
return context
if context is None:
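Note: the rewritten keyword handling parses the model's JSON and bails out mode-dependently: an empty low_level_keywords list aborts local/hybrid queries, an empty high_level_keywords list aborts global/hybrid ones. A toy run of that gating (the fail string matches PROMPTS["fail_response"] shown later in this commit):

import json

result = '{"high_level_keywords": ["Trade policy"], "low_level_keywords": []}'
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

mode = "hybrid"
fail = "Sorry, I'm not able to provide an answer to that question."
if hl_keywords == [] and ll_keywords == []:
    answer = fail
elif ll_keywords == [] and mode in ["local", "hybrid"]:
    answer = fail  # kg_query warns, then returns the fail response
elif hl_keywords == [] and mode in ["global", "hybrid"]:
    answer = fail
else:
    answer = "build context"
print(answer)  # the empty low-level list trips the hybrid check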
@@ -468,13 +476,13 @@ async def kg_query(
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
-        )
+    )
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
-        )
+    )
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -496,44 +504,72 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
-    ):
+):
ll_kewwords, hl_keywrds = query[0], query[1]
if query_param.mode in ["local", "hybrid"]:
if ll_kewwords == "":
-            ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
-            warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
+            ll_entities_context, ll_relations_context, ll_text_units_context = (
+                "",
+                "",
+                "",
+            )
+            warnings.warn(
+                "Low Level context is None. Return empty Low entity/relationship/source"
+            )
query_param.mode = "global"
else:
-            ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
+            (
+                ll_entities_context,
+                ll_relations_context,
+                ll_text_units_context,
+            ) = await _get_node_data(
ll_kewwords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
-                query_param
-            )
+                query_param,
+            )
if query_param.mode in ["global", "hybrid"]:
if hl_keywrds == "":
-            hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
-            warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+            hl_entities_context, hl_relations_context, hl_text_units_context = (
+                "",
+                "",
+                "",
+            )
+            warnings.warn(
+                "High Level context is None. Return empty High entity/relationship/source"
+            )
query_param.mode = "local"
else:
-            hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
+            (
+                hl_entities_context,
+                hl_relations_context,
+                hl_text_units_context,
+            ) = await _get_edge_data(
hl_keywrds,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
-                query_param
-            )
-    if query_param.mode == 'hybrid':
-        entities_context,relations_context,text_units_context = combine_contexts(
-            [hl_entities_context,ll_entities_context],
-            [hl_relations_context,ll_relations_context],
-            [hl_text_units_context,ll_text_units_context]
-            )
-    elif query_param.mode == 'local':
-        entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
-    elif query_param.mode == 'global':
-        entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
+                query_param,
+            )
+    if query_param.mode == "hybrid":
+        entities_context, relations_context, text_units_context = combine_contexts(
+            [hl_entities_context, ll_entities_context],
+            [hl_relations_context, ll_relations_context],
+            [hl_text_units_context, ll_text_units_context],
+        )
+    elif query_param.mode == "local":
+        entities_context, relations_context, text_units_context = (
+            ll_entities_context,
+            ll_relations_context,
+            ll_text_units_context,
+        )
+    elif query_param.mode == "global":
+        entities_context, relations_context, text_units_context = (
+            hl_entities_context,
+            hl_relations_context,
+            hl_text_units_context,
+        )
return f"""
# -----Entities-----
# ```csv
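Note: _build_query_context now degrades instead of failing when one keyword side is empty: hybrid falls back to global when low-level keywords are missing, and to local when high-level keywords are missing. A compact illustration of that fallback:

mode, ll_keywords, hl_keywords = "hybrid", "", "trade, economy"
if mode in ["local", "hybrid"] and ll_keywords == "":
    mode = "global"  # only edge (high-level) context can be built
if mode in ["global", "hybrid"] and hl_keywords == "":
    mode = "local"  # only node (low-level) context can be built
print(mode)  # -> global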
@@ -550,7 +586,6 @@ async def _build_query_context(
# """
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -568,7 +603,7 @@ async def _get_node_data(
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
    # Get the degree of each entity
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
@@ -588,7 +623,7 @@ async def _get_node_data(
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
-        )
+    )
    # Build the prompt
entites_section_list = [["id", "entity", "type", "description", "rank"]]
@@ -625,7 +660,7 @@ async def _get_node_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
-    return entities_context,relations_context,text_units_context
+    return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities(
@@ -821,8 +856,7 @@ async def _get_edge_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
-    return entities_context,relations_context,text_units_context
+    return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships(
@@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships(
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
hl_entities, ll_entities = entities[0], entities[1]
-    hl_relationships, ll_relationships = relationships[0],relationships[1]
+    hl_relationships, ll_relationships = relationships[0], relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
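Note: combine_contexts funnels the high- and low-level CSV blocks through process_combine_contexts, defined elsewhere in the repo. A hypothetical stand-in, assuming it deduplicates rows by content and renumbers the id column:

def toy_process_combine_contexts(hl_csv: str, ll_csv: str) -> str:
    # Hypothetical: merge two CSV blocks, drop duplicate rows (ignoring
    # the old id column), renumber ids; not the repo's implementation.
    header, seen, rows = None, set(), []
    for block in (hl_csv, ll_csv):
        lines = [ln for ln in block.splitlines() if ln.strip()]
        if not lines:
            continue
        header = header or lines[0]
        for ln in lines[1:]:
            body = ln.split(",", 1)[-1]
            if body not in seen:
                seen.add(body)
                rows.append(body)
    return "\n".join([header or ""] + [f"{i},{b}" for i, b in enumerate(rows)])

print(toy_process_combine_contexts(
    "id,entity\n0,Alex\n1,Taylor", "id,entity\n0,Taylor\n1,Jordan"
))  # id,entity / 0,Alex / 1,Taylor / 2,Jordan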


@@ -52,7 +52,7 @@ Output:
"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
"""Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -77,7 +77,7 @@ Output:
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
#############################""",
"""Example 2:
"""Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -95,7 +95,7 @@ Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
#############################""",
"""Example 3:
"""Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
@@ -121,10 +121,12 @@ Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
#############################"""
#############################""",
]
PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
PROMPTS[
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
@@ -139,10 +141,14 @@ Description List: {description_list}
Output:
"""
PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format:
PROMPTS[
"entiti_continue_extraction"
] = """MANY entities were missed in the last extraction. Add them below using the same format:
"""
PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
PROMPTS[
"entiti_if_loop_extraction"
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
"""
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
@@ -201,7 +207,7 @@ Output:
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
"""Example 1:
Query: "How does international trade influence global economic stability?"
################
@@ -211,7 +217,7 @@ Output:
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}}
#############################""",
"""Example 2:
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
################
@@ -220,8 +226,8 @@ Output:
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}}
#############################""",
"""Example 3:
#############################""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
################
@@ -230,8 +236,8 @@ Output:
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}}
#############################"""
]
#############################""",
]
PROMPTS["naive_rag_response"] = """---Role---


@@ -56,7 +56,8 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
maybe_json_str = maybe_json_str.replace("'", '"')
json.loads(maybe_json_str)
return maybe_json_str
-        except:
+        except Exception:
pass
# try:
# content = (
# content.replace(kw_prompt[:-1], "")
@@ -64,9 +65,9 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
# .replace("model", "")
# .strip()
# )
-    #         maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
+    #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
# json.loads(maybe_json_str)
return None
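Note: narrowing the bare except: to except Exception: stops the probe from swallowing KeyboardInterrupt and SystemExit while testing candidate JSON. A reduced sketch of the surrounding extractor, assuming a regex-based search like the one in utils.py:

import json
import re
from typing import Union

def locate_json_body(content: str) -> Union[str, None]:
    # Reduced, hypothetical version of locate_json_string_body_from_string.
    match = re.search(r"\{.*\}", content, re.DOTALL)
    if match is not None:
        candidate = match.group(0).replace("'", '"')
        try:
            json.loads(candidate)
            return candidate
        except Exception:  # was a bare `except:`, which would also trap
            pass           # KeyboardInterrupt and SystemExit
    return None

print(locate_json_body("noise {'high_level_keywords': []} noise"))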