chore: added pre-commit-hooks and ruff formatting for commit-hooks
This commit is contained in:
@@ -25,6 +25,7 @@ from .base import (
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||
|
||||
|
||||
def chunking_by_token_size(
|
||||
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
||||
):
|
||||
@@ -45,6 +46,7 @@ def chunking_by_token_size(
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _handle_entity_relation_summary(
|
||||
entity_or_relation_name: str,
|
||||
description: str,
|
||||
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
|
||||
description=description,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
|
||||
return edge_data
|
||||
|
||||
|
||||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
@@ -352,7 +355,9 @@ async def extract_entities(
|
||||
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
||||
return None
|
||||
if not len(all_relationships_data):
|
||||
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
|
||||
logger.warning(
|
||||
"Didn't extract any relationships, maybe your LLM is not working"
|
||||
)
|
||||
return None
|
||||
|
||||
if entity_vdb is not None:
|
||||
@@ -370,7 +375,10 @@ async def extract_entities(
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
|
||||
"content": dp["keywords"]
|
||||
+ dp["src_id"]
|
||||
+ dp["tgt_id"]
|
||||
+ dp["description"],
|
||||
}
|
||||
for dp in all_relationships_data
|
||||
}
|
||||
@@ -378,6 +386,7 @@ async def extract_entities(
|
||||
|
||||
return knwoledge_graph_inst
|
||||
|
||||
|
||||
async def local_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -393,19 +402,24 @@ async def local_query(
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
result = await use_model_func(kw_prompt)
|
||||
|
||||
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("low_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
keywords = ", ".join(keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("low_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
keywords = ", ".join(keywords)
|
||||
# Handle parsing error
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
@@ -430,11 +444,20 @@ async def local_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _build_local_query_context(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -516,6 +539,7 @@ async def _build_local_query_context(
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
async def _find_most_related_text_unit_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
||||
return all_text_units
|
||||
|
||||
|
||||
async def _find_most_related_edges_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
|
||||
)
|
||||
return all_edges_data
|
||||
|
||||
|
||||
async def global_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -624,20 +650,25 @@ async def global_query(
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
result = await use_model_func(kw_prompt)
|
||||
|
||||
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("high_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
keywords = ", ".join(keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("high_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
|
||||
keywords = ", ".join(keywords)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# Handle parsing error
|
||||
print(f"JSON parsing error: {e}")
|
||||
@@ -651,12 +682,12 @@ async def global_query(
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
|
||||
if query_param.only_need_context:
|
||||
return context
|
||||
if context is None:
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
@@ -665,11 +696,20 @@ async def global_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _build_global_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -679,14 +719,14 @@ async def _build_global_query_context(
|
||||
query_param: QueryParam,
|
||||
):
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
|
||||
|
||||
if not len(results):
|
||||
return None
|
||||
|
||||
|
||||
edge_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
)
|
||||
|
||||
|
||||
if not all([n is not None for n in edge_datas]):
|
||||
logger.warning("Some edges are missing, maybe the storage is damaged")
|
||||
edge_degree = await asyncio.gather(
|
||||
@@ -765,6 +805,7 @@ async def _build_global_query_context(
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
async def _find_most_related_entities_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
|
||||
for e in edge_datas:
|
||||
entity_names.add(e["src_id"])
|
||||
entity_names.add(e["tgt_id"])
|
||||
|
||||
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
||||
)
|
||||
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
||||
|
||||
return node_datas
|
||||
|
||||
|
||||
async def _find_related_text_unit_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
for dp in edge_datas
|
||||
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
|
||||
"data": await text_chunks_db.get_by_id(c_id),
|
||||
"order": index,
|
||||
}
|
||||
|
||||
|
||||
if any([v is None for v in all_text_units_lookup.values()]):
|
||||
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
||||
all_text_units = [
|
||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||
]
|
||||
all_text_units = sorted(
|
||||
all_text_units, key=lambda x: x["order"]
|
||||
)
|
||||
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
||||
all_text_units = truncate_list_by_token_size(
|
||||
all_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
||||
|
||||
return all_text_units
|
||||
|
||||
|
||||
async def hybrid_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -849,24 +889,29 @@ async def hybrid_query(
|
||||
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
|
||||
|
||||
result = await use_model_func(kw_prompt)
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
hl_keywords = ', '.join(hl_keywords)
|
||||
ll_keywords = ', '.join(ll_keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
hl_keywords = ', '.join(hl_keywords)
|
||||
ll_keywords = ', '.join(ll_keywords)
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
# Handle parsing error
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
@@ -897,7 +942,7 @@ async def hybrid_query(
|
||||
return context
|
||||
if context is None:
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
@@ -906,53 +951,78 @@ async def hybrid_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def combine_contexts(high_level_context, low_level_context):
|
||||
# Function to extract entities, relationships, and sources from context strings
|
||||
|
||||
def extract_sections(context):
|
||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
|
||||
entities = entities_match.group(1) if entities_match else ''
|
||||
relationships = relationships_match.group(1) if relationships_match else ''
|
||||
sources = sources_match.group(1) if sources_match else ''
|
||||
|
||||
entities_match = re.search(
|
||||
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
relationships_match = re.search(
|
||||
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
sources_match = re.search(
|
||||
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
|
||||
entities = entities_match.group(1) if entities_match else ""
|
||||
relationships = relationships_match.group(1) if relationships_match else ""
|
||||
sources = sources_match.group(1) if sources_match else ""
|
||||
|
||||
return entities, relationships, sources
|
||||
|
||||
|
||||
# Extract sections from both contexts
|
||||
|
||||
if high_level_context==None:
|
||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
||||
hl_entities, hl_relationships, hl_sources = '','',''
|
||||
if high_level_context is None:
|
||||
warnings.warn(
|
||||
"High Level context is None. Return empty High entity/relationship/source"
|
||||
)
|
||||
hl_entities, hl_relationships, hl_sources = "", "", ""
|
||||
else:
|
||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||
|
||||
|
||||
if low_level_context==None:
|
||||
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
||||
ll_entities, ll_relationships, ll_sources = '','',''
|
||||
if low_level_context is None:
|
||||
warnings.warn(
|
||||
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||
)
|
||||
ll_entities, ll_relationships, ll_sources = "", "", ""
|
||||
else:
|
||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||
|
||||
|
||||
|
||||
# Combine and deduplicate the entities
|
||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
||||
combined_entities = '\n'.join(combined_entities_set)
|
||||
|
||||
combined_entities_set = set(
|
||||
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
||||
)
|
||||
combined_entities = "\n".join(combined_entities_set)
|
||||
|
||||
# Combine and deduplicate the relationships
|
||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
||||
combined_relationships = '\n'.join(combined_relationships_set)
|
||||
|
||||
combined_relationships_set = set(
|
||||
filter(
|
||||
None,
|
||||
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
||||
)
|
||||
)
|
||||
combined_relationships = "\n".join(combined_relationships_set)
|
||||
|
||||
# Combine and deduplicate the sources
|
||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
||||
combined_sources = '\n'.join(combined_sources_set)
|
||||
|
||||
combined_sources_set = set(
|
||||
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
||||
)
|
||||
combined_sources = "\n".join(combined_sources_set)
|
||||
|
||||
# Format the combined context
|
||||
return f"""
|
||||
-----Entities-----
|
||||
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
||||
{combined_sources}
|
||||
"""
|
||||
|
||||
|
||||
async def naive_query(
|
||||
query,
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
@@ -996,8 +1067,16 @@ async def naive_query(
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
|
||||
return response
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response[len(sys_prompt) :]
|
||||
.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
Reference in New Issue
Block a user