chore: added pre-commit-hooks and ruff formatting for commit-hooks

Author: Sanketh Kumar
Date: 2024-10-19 09:43:17 +05:30
parent b854ab4737
commit 744dad339d
26 changed files with 635 additions and 393 deletions


@@ -25,6 +25,7 @@ from .base import (
 )
 from .prompt import GRAPH_FIELD_SEP, PROMPTS

+
 def chunking_by_token_size(
     content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
         )
     return results

+
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
         description=description,
         keywords=keywords,
     )
     return edge_data

+
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
     knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
         logger.warning("Didn't extract any entities, maybe your LLM is not working")
         return None
     if not len(all_relationships_data):
-        logger.warning("Didn't extract any relationships, maybe your LLM is not working")
+        logger.warning(
+            "Didn't extract any relationships, maybe your LLM is not working"
+        )
         return None
     if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
             compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
-                "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
+                "content": dp["keywords"]
+                + dp["src_id"]
+                + dp["tgt_id"]
+                + dp["description"],
             }
             for dp in all_relationships_data
         }
@@ -378,6 +386,7 @@ async def extract_entities(
     return knwoledge_graph_inst

+
 async def local_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -393,19 +402,24 @@ async def local_query(
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     try:
         keywords_data = json.loads(result)
         keywords = keywords_data.get("low_level_keywords", [])
-        keywords = ', '.join(keywords)
-    except json.JSONDecodeError as e:
+        keywords = ", ".join(keywords)
+    except json.JSONDecodeError:
         try:
-            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = (
+                result.replace(kw_prompt[:-1], "")
+                .replace("user", "")
+                .replace("model", "")
+                .strip()
+            )
+            result = "{" + result.split("{")[1].split("}")[0] + "}"
             keywords_data = json.loads(result)
             keywords = keywords_data.get("low_level_keywords", [])
-            keywords = ', '.join(keywords)
+            keywords = ", ".join(keywords)
         # Handle parsing error
         except json.JSONDecodeError as e:
             print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
         query,
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
     return response

+
 async def _build_local_query_context(
     query,
     knowledge_graph_inst: BaseGraphStorage,
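Note: the same prompt-echo cleanup reappears in global_query, hybrid_query, and naive_query below; the diff only breaks the one-line replace chain across lines. As a standalone illustration (function name and sample strings are hypothetical, not from the commit; note the blunt substring replaces of "user" and "model", carried over from the original code):

def strip_prompt_echo(response: str, sys_prompt: str, query: str) -> str:
    # If the reply is longer than the system prompt, assume it may echo the
    # prompt, the query, chat-role words, or <system> tags, and strip them.
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response


if __name__ == "__main__":
    prompt = "<system>Answer strictly from the given context.</system>"
    raw = prompt + " model The capital of France is Paris."
    print(strip_prompt_echo(raw, prompt, query="What is the capital of France?"))
    # -> "The capital of France is Paris."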
@@ -516,6 +539,7 @@ async def _build_local_query_context(
```
"""
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
     all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
     return all_text_units

+
 async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
     )
     return all_edges_data

+
 async def global_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -624,20 +650,25 @@ async def global_query(
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     try:
         keywords_data = json.loads(result)
         keywords = keywords_data.get("high_level_keywords", [])
-        keywords = ', '.join(keywords)
-    except json.JSONDecodeError as e:
+        keywords = ", ".join(keywords)
+    except json.JSONDecodeError:
         try:
-            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = (
+                result.replace(kw_prompt[:-1], "")
+                .replace("user", "")
+                .replace("model", "")
+                .strip()
+            )
+            result = "{" + result.split("{")[1].split("}")[0] + "}"
             keywords_data = json.loads(result)
             keywords = keywords_data.get("high_level_keywords", [])
-            keywords = ', '.join(keywords)
+            keywords = ", ".join(keywords)
         except json.JSONDecodeError as e:
             # Handle parsing error
             print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
         text_chunks_db,
         query_param,
     )
     if query_param.only_need_context:
         return context
     if context is None:
         return PROMPTS["fail_response"]
     sys_prompt_temp = PROMPTS["rag_response"]
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
@@ -665,11 +696,20 @@ async def global_query(
         query,
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
     return response

+
 async def _build_global_query_context(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
@@ -679,14 +719,14 @@ async def _build_global_query_context(
     query_param: QueryParam,
 ):
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
     if not len(results):
         return None
     edge_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
     )
     if not all([n is not None for n in edge_datas]):
         logger.warning("Some edges are missing, maybe the storage is damaged")
     edge_degree = await asyncio.gather(
@@ -765,6 +805,7 @@ async def _build_global_query_context(
```
"""
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
     for e in edge_datas:
         entity_names.add(e["src_id"])
         entity_names.add(e["tgt_id"])
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
     )
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
     return node_datas

 async def _find_related_text_unit_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
         for dp in edge_datas
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
                     "data": await text_chunks_db.get_by_id(c_id),
                     "order": index,
                 }
     if any([v is None for v in all_text_units_lookup.values()]):
         logger.warning("Text chunks are missing, maybe the storage is damaged")
     all_text_units = [
         {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
     ]
-    all_text_units = sorted(
-        all_text_units, key=lambda x: x["order"]
-    )
+    all_text_units = sorted(all_text_units, key=lambda x: x["order"])
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units

+
 async def hybrid_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -849,24 +889,29 @@ async def hybrid_query(
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     try:
         keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
-        hl_keywords = ', '.join(hl_keywords)
-        ll_keywords = ', '.join(ll_keywords)
-    except json.JSONDecodeError as e:
+        hl_keywords = ", ".join(hl_keywords)
+        ll_keywords = ", ".join(ll_keywords)
+    except json.JSONDecodeError:
         try:
-            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = (
+                result.replace(kw_prompt[:-1], "")
+                .replace("user", "")
+                .replace("model", "")
+                .strip()
+            )
+            result = "{" + result.split("{")[1].split("}")[0] + "}"
             keywords_data = json.loads(result)
             hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
-            hl_keywords = ', '.join(hl_keywords)
-            ll_keywords = ', '.join(ll_keywords)
+            hl_keywords = ", ".join(hl_keywords)
+            ll_keywords = ", ".join(ll_keywords)
         # Handle parsing error
         except json.JSONDecodeError as e:
             print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
         return context
     if context is None:
         return PROMPTS["fail_response"]
     sys_prompt_temp = PROMPTS["rag_response"]
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
@@ -906,53 +951,78 @@ async def hybrid_query(
         query,
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
     return response

 def combine_contexts(high_level_context, low_level_context):
     # Function to extract entities, relationships, and sources from context strings
     def extract_sections(context):
-        entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-        relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-        sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-        entities = entities_match.group(1) if entities_match else ''
-        relationships = relationships_match.group(1) if relationships_match else ''
-        sources = sources_match.group(1) if sources_match else ''
+        entities_match = re.search(
+            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+        relationships_match = re.search(
+            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+        sources_match = re.search(
+            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+        entities = entities_match.group(1) if entities_match else ""
+        relationships = relationships_match.group(1) if relationships_match else ""
+        sources = sources_match.group(1) if sources_match else ""
         return entities, relationships, sources

     # Extract sections from both contexts
-    if high_level_context==None:
-        warnings.warn("High Level context is None. Return empty High entity/relationship/source")
-        hl_entities, hl_relationships, hl_sources = '','',''
+    if high_level_context is None:
+        warnings.warn(
+            "High Level context is None. Return empty High entity/relationship/source"
+        )
+        hl_entities, hl_relationships, hl_sources = "", "", ""
     else:
         hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)

-    if low_level_context==None:
-        warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
-        ll_entities, ll_relationships, ll_sources = '','',''
+    if low_level_context is None:
+        warnings.warn(
+            "Low Level context is None. Return empty Low entity/relationship/source"
+        )
+        ll_entities, ll_relationships, ll_sources = "", "", ""
     else:
         ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)

     # Combine and deduplicate the entities
-    combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
-    combined_entities = '\n'.join(combined_entities_set)
+    combined_entities_set = set(
+        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
+    )
+    combined_entities = "\n".join(combined_entities_set)

     # Combine and deduplicate the relationships
-    combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
-    combined_relationships = '\n'.join(combined_relationships_set)
+    combined_relationships_set = set(
+        filter(
+            None,
+            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
+        )
+    )
+    combined_relationships = "\n".join(combined_relationships_set)

     # Combine and deduplicate the sources
-    combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
-    combined_sources = '\n'.join(combined_sources_set)
+    combined_sources_set = set(
+        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
+    )
+    combined_sources = "\n".join(combined_sources_set)

     # Format the combined context
     return f"""
 -----Entities-----
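Note: the combine_contexts hunk above is again formatting-only; the underlying logic pulls each -----Entities-----/-----Relationships-----/-----Sources----- csv block out of the two context strings with a regex and deduplicates rows via a set, which does not preserve row order. A compact sketch of that extract-and-merge step with made-up context strings; the helper names are illustrative, not from the commit.

import re


def extract_section(context: str, name: str) -> str:
    # Grab the ```csv ... ``` block that follows a "-----<name>-----" banner.
    match = re.search(
        rf"-----{name}-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
    )
    return match.group(1) if match else ""


def merge_sections(high: str, low: str, name: str) -> str:
    hl, ll = extract_section(high, name), extract_section(low, name)
    # Drop empty lines and deduplicate rows from both contexts (order is arbitrary).
    rows = set(filter(None, hl.strip().split("\n") + ll.strip().split("\n")))
    return "\n".join(rows)


if __name__ == "__main__":
    high = "-----Entities-----\n```csv\nid,name\n1,alpha\n```"
    low = "-----Entities-----\n```csv\nid,name\n2,beta\n```"
    print(merge_sections(high, low, "Entities"))
    # Prints the deduplicated rows: id,name / 1,alpha / 2,beta (in set order).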
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
 {combined_sources}
 """

+
 async def naive_query(
     query,
     chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
-    return response
+    if len(response) > len(sys_prompt):
+        response = (
+            response[len(sys_prompt) :]
+            .replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
+    return response