update
This commit is contained in:
944
lightrag/operate.py
Normal file
944
lightrag/operate.py
Normal file
@@ -0,0 +1,944 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Union
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
from .utils import (
|
||||
logger,
|
||||
clean_str,
|
||||
compute_mdhash_id,
|
||||
decode_tokens_by_tiktoken,
|
||||
encode_string_by_tiktoken,
|
||||
is_float_regex,
|
||||
list_of_list_to_csv,
|
||||
pack_user_ass_to_openai_messages,
|
||||
split_string_by_multi_markers,
|
||||
truncate_list_by_token_size,
|
||||
)
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
TextChunkSchema,
|
||||
QueryParam,
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||
|
||||
def chunking_by_token_size(
|
||||
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
||||
):
|
||||
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
|
||||
results = []
|
||||
for index, start in enumerate(
|
||||
range(0, len(tokens), max_token_size - overlap_token_size)
|
||||
):
|
||||
chunk_content = decode_tokens_by_tiktoken(
|
||||
tokens[start : start + max_token_size], model_name=tiktoken_model
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"tokens": min(max_token_size, len(tokens) - start),
|
||||
"content": chunk_content.strip(),
|
||||
"chunk_order_index": index,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
async def _handle_entity_relation_summary(
|
||||
entity_or_relation_name: str,
|
||||
description: str,
|
||||
global_config: dict,
|
||||
) -> str:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
llm_max_tokens = global_config["llm_model_max_token_size"]
|
||||
tiktoken_model_name = global_config["tiktoken_model_name"]
|
||||
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
|
||||
|
||||
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
|
||||
if len(tokens) < summary_max_tokens: # No need for summary
|
||||
return description
|
||||
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
||||
use_description = decode_tokens_by_tiktoken(
|
||||
tokens[:llm_max_tokens], model_name=tiktoken_model_name
|
||||
)
|
||||
context_base = dict(
|
||||
entity_name=entity_or_relation_name,
|
||||
description_list=use_description.split(GRAPH_FIELD_SEP),
|
||||
)
|
||||
use_prompt = prompt_template.format(**context_base)
|
||||
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
||||
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
|
||||
return summary
|
||||
|
||||
|
||||
async def _handle_single_entity_extraction(
|
||||
record_attributes: list[str],
|
||||
chunk_key: str,
|
||||
):
|
||||
if record_attributes[0] != '"entity"' or len(record_attributes) < 4:
|
||||
return None
|
||||
# add this record as a node in the G
|
||||
entity_name = clean_str(record_attributes[1].upper())
|
||||
if not entity_name.strip():
|
||||
return None
|
||||
entity_type = clean_str(record_attributes[2].upper())
|
||||
entity_description = clean_str(record_attributes[3])
|
||||
entity_source_id = chunk_key
|
||||
return dict(
|
||||
entity_name=entity_name,
|
||||
entity_type=entity_type,
|
||||
description=entity_description,
|
||||
source_id=entity_source_id,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_single_relationship_extraction(
|
||||
record_attributes: list[str],
|
||||
chunk_key: str,
|
||||
):
|
||||
if record_attributes[0] != '"relationship"' or len(record_attributes) < 5:
|
||||
return None
|
||||
# add this record as edge
|
||||
source = clean_str(record_attributes[1].upper())
|
||||
target = clean_str(record_attributes[2].upper())
|
||||
edge_description = clean_str(record_attributes[3])
|
||||
|
||||
edge_keywords = clean_str(record_attributes[4])
|
||||
edge_source_id = chunk_key
|
||||
weight = (
|
||||
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
|
||||
)
|
||||
return dict(
|
||||
src_id=source,
|
||||
tgt_id=target,
|
||||
weight=weight,
|
||||
description=edge_description,
|
||||
keywords=edge_keywords,
|
||||
source_id=edge_source_id,
|
||||
)
|
||||
|
||||
|
||||
async def _merge_nodes_then_upsert(
|
||||
entity_name: str,
|
||||
nodes_data: list[dict],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict,
|
||||
):
|
||||
already_entitiy_types = []
|
||||
already_source_ids = []
|
||||
already_description = []
|
||||
|
||||
already_node = await knwoledge_graph_inst.get_node(entity_name)
|
||||
if already_node is not None:
|
||||
already_entitiy_types.append(already_node["entity_type"])
|
||||
already_source_ids.extend(
|
||||
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
|
||||
)
|
||||
already_description.append(already_node["description"])
|
||||
|
||||
entity_type = sorted(
|
||||
Counter(
|
||||
[dp["entity_type"] for dp in nodes_data] + already_entitiy_types
|
||||
).items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[0][0]
|
||||
description = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["description"] for dp in nodes_data] + already_description))
|
||||
)
|
||||
source_id = GRAPH_FIELD_SEP.join(
|
||||
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
|
||||
)
|
||||
description = await _handle_entity_relation_summary(
|
||||
entity_name, description, global_config
|
||||
)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
source_id=source_id,
|
||||
)
|
||||
await knwoledge_graph_inst.upsert_node(
|
||||
entity_name,
|
||||
node_data=node_data,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
return node_data
|
||||
|
||||
|
||||
async def _merge_edges_then_upsert(
|
||||
src_id: str,
|
||||
tgt_id: str,
|
||||
edges_data: list[dict],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict,
|
||||
):
|
||||
already_weights = []
|
||||
already_source_ids = []
|
||||
already_description = []
|
||||
already_keywords = []
|
||||
|
||||
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
|
||||
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
|
||||
already_weights.append(already_edge["weight"])
|
||||
already_source_ids.extend(
|
||||
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
|
||||
)
|
||||
already_description.append(already_edge["description"])
|
||||
already_keywords.extend(
|
||||
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
|
||||
)
|
||||
|
||||
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
||||
description = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
||||
)
|
||||
keywords = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
|
||||
)
|
||||
source_id = GRAPH_FIELD_SEP.join(
|
||||
set([dp["source_id"] for dp in edges_data] + already_source_ids)
|
||||
)
|
||||
for need_insert_id in [src_id, tgt_id]:
|
||||
if not (await knwoledge_graph_inst.has_node(need_insert_id)):
|
||||
await knwoledge_graph_inst.upsert_node(
|
||||
need_insert_id,
|
||||
node_data={
|
||||
"source_id": source_id,
|
||||
"description": description,
|
||||
"entity_type": '"UNKNOWN"',
|
||||
},
|
||||
)
|
||||
description = await _handle_entity_relation_summary(
|
||||
(src_id, tgt_id), description, global_config
|
||||
)
|
||||
await knwoledge_graph_inst.upsert_edge(
|
||||
src_id,
|
||||
tgt_id,
|
||||
edge_data=dict(
|
||||
weight=weight,
|
||||
description=description,
|
||||
keywords=keywords,
|
||||
source_id=source_id,
|
||||
),
|
||||
)
|
||||
|
||||
edge_data = dict(
|
||||
src_id=src_id,
|
||||
tgt_id=tgt_id,
|
||||
description=description,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
return edge_data
|
||||
|
||||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
entity_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
global_config: dict,
|
||||
) -> Union[BaseGraphStorage, None]:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
|
||||
ordered_chunks = list(chunks.items())
|
||||
|
||||
entity_extract_prompt = PROMPTS["entity_extraction"]
|
||||
context_base = dict(
|
||||
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
||||
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
||||
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
||||
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
|
||||
)
|
||||
continue_prompt = PROMPTS["entiti_continue_extraction"]
|
||||
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
|
||||
|
||||
already_processed = 0
|
||||
already_entities = 0
|
||||
already_relations = 0
|
||||
|
||||
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
||||
nonlocal already_processed, already_entities, already_relations
|
||||
chunk_key = chunk_key_dp[0]
|
||||
chunk_dp = chunk_key_dp[1]
|
||||
content = chunk_dp["content"]
|
||||
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
|
||||
final_result = await use_llm_func(hint_prompt)
|
||||
|
||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
||||
for now_glean_index in range(entity_extract_max_gleaning):
|
||||
glean_result = await use_llm_func(continue_prompt, history_messages=history)
|
||||
|
||||
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
||||
final_result += glean_result
|
||||
if now_glean_index == entity_extract_max_gleaning - 1:
|
||||
break
|
||||
|
||||
if_loop_result: str = await use_llm_func(
|
||||
if_loop_prompt, history_messages=history
|
||||
)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
break
|
||||
|
||||
records = split_string_by_multi_markers(
|
||||
final_result,
|
||||
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
||||
)
|
||||
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
record = record.group(1)
|
||||
record_attributes = split_string_by_multi_markers(
|
||||
record, [context_base["tuple_delimiter"]]
|
||||
)
|
||||
if_entities = await _handle_single_entity_extraction(
|
||||
record_attributes, chunk_key
|
||||
)
|
||||
if if_entities is not None:
|
||||
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
||||
continue
|
||||
|
||||
if_relation = await _handle_single_relationship_extraction(
|
||||
record_attributes, chunk_key
|
||||
)
|
||||
if if_relation is not None:
|
||||
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
||||
if_relation
|
||||
)
|
||||
already_processed += 1
|
||||
already_entities += len(maybe_nodes)
|
||||
already_relations += len(maybe_edges)
|
||||
now_ticks = PROMPTS["process_tickers"][
|
||||
already_processed % len(PROMPTS["process_tickers"])
|
||||
]
|
||||
print(
|
||||
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
# use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
|
||||
results = await asyncio.gather(
|
||||
*[_process_single_content(c) for c in ordered_chunks]
|
||||
)
|
||||
print() # clear the progress bar
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
for m_nodes, m_edges in results:
|
||||
for k, v in m_nodes.items():
|
||||
maybe_nodes[k].extend(v)
|
||||
for k, v in m_edges.items():
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
all_entities_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
|
||||
for k, v in maybe_nodes.items()
|
||||
]
|
||||
)
|
||||
all_relationships_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
|
||||
for k, v in maybe_edges.items()
|
||||
]
|
||||
)
|
||||
if not len(all_entities_data):
|
||||
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
||||
return None
|
||||
if not len(all_relationships_data):
|
||||
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
|
||||
return None
|
||||
|
||||
if entity_vdb is not None:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
||||
"content": dp["entity_name"] + dp["description"],
|
||||
"entity_name": dp["entity_name"],
|
||||
}
|
||||
for dp in all_entities_data
|
||||
}
|
||||
await entity_vdb.upsert(data_for_vdb)
|
||||
|
||||
if relationships_vdb is not None:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
|
||||
}
|
||||
for dp in all_relationships_data
|
||||
}
|
||||
await relationships_vdb.upsert(data_for_vdb)
|
||||
|
||||
return knwoledge_graph_inst
|
||||
|
||||
async def local_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
) -> str:
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
result = await use_model_func(kw_prompt)
|
||||
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("low_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
# Handle parsing error
|
||||
print(f"JSON parsing error: {e}")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
context = await _build_local_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
if query_param.only_need_context:
|
||||
return context
|
||||
if context is None:
|
||||
return PROMPTS["fail_response"]
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
)
|
||||
response = await use_model_func(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
return response
|
||||
|
||||
async def _build_local_query_context(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
):
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||
if not len(results):
|
||||
return None
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
)
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
node_degrees = await asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k["entity_name"], "rank": d}
|
||||
for k, n, d in zip(results, node_datas, node_degrees)
|
||||
if n is not None
|
||||
]
|
||||
use_text_units = await _find_most_related_text_unit_from_entities(
|
||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
)
|
||||
use_relations = await _find_most_related_edges_from_entities(
|
||||
node_datas, query_param, knowledge_graph_inst
|
||||
)
|
||||
logger.info(
|
||||
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
|
||||
)
|
||||
entites_section_list = [["id", "entity", "type", "description", "rank"]]
|
||||
for i, n in enumerate(node_datas):
|
||||
entites_section_list.append(
|
||||
[
|
||||
i,
|
||||
n["entity_name"],
|
||||
n.get("entity_type", "UNKNOWN"),
|
||||
n.get("description", "UNKNOWN"),
|
||||
n["rank"],
|
||||
]
|
||||
)
|
||||
entities_context = list_of_list_to_csv(entites_section_list)
|
||||
|
||||
relations_section_list = [
|
||||
["id", "source", "target", "description", "keywords", "weight", "rank"]
|
||||
]
|
||||
for i, e in enumerate(use_relations):
|
||||
relations_section_list.append(
|
||||
[
|
||||
i,
|
||||
e["src_tgt"][0],
|
||||
e["src_tgt"][1],
|
||||
e["description"],
|
||||
e["keywords"],
|
||||
e["weight"],
|
||||
e["rank"],
|
||||
]
|
||||
)
|
||||
relations_context = list_of_list_to_csv(relations_section_list)
|
||||
|
||||
text_units_section_list = [["id", "content"]]
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return f"""
|
||||
-----Entities-----
|
||||
```csv
|
||||
{entities_context}
|
||||
```
|
||||
-----Relationships-----
|
||||
```csv
|
||||
{relations_context}
|
||||
```
|
||||
-----Sources-----
|
||||
```csv
|
||||
{text_units_context}
|
||||
```
|
||||
"""
|
||||
|
||||
async def _find_most_related_text_unit_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
for dp in node_datas
|
||||
]
|
||||
edges = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
||||
)
|
||||
all_one_hop_nodes = set()
|
||||
for this_edges in edges:
|
||||
if not this_edges:
|
||||
continue
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
|
||||
)
|
||||
all_one_hop_text_units_lookup = {
|
||||
k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
|
||||
for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
|
||||
if v is not None
|
||||
}
|
||||
all_text_units_lookup = {}
|
||||
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
||||
for c_id in this_text_units:
|
||||
if c_id in all_text_units_lookup:
|
||||
continue
|
||||
relation_counts = 0
|
||||
for e in this_edges:
|
||||
if (
|
||||
e[1] in all_one_hop_text_units_lookup
|
||||
and c_id in all_one_hop_text_units_lookup[e[1]]
|
||||
):
|
||||
relation_counts += 1
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": await text_chunks_db.get_by_id(c_id),
|
||||
"order": index,
|
||||
"relation_counts": relation_counts,
|
||||
}
|
||||
if any([v is None for v in all_text_units_lookup.values()]):
|
||||
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
||||
all_text_units = [
|
||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||
]
|
||||
all_text_units = sorted(
|
||||
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
||||
)
|
||||
all_text_units = truncate_list_by_token_size(
|
||||
all_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
)
|
||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
||||
return all_text_units
|
||||
|
||||
async def _find_most_related_edges_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
all_related_edges = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
||||
)
|
||||
all_edges = set()
|
||||
for this_edges in all_related_edges:
|
||||
all_edges.update([tuple(sorted(e)) for e in this_edges])
|
||||
all_edges = list(all_edges)
|
||||
all_edges_pack = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
||||
)
|
||||
all_edges_degree = await asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
||||
)
|
||||
all_edges_data = [
|
||||
{"src_tgt": k, "rank": d, **v}
|
||||
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
|
||||
if v is not None
|
||||
]
|
||||
all_edges_data = sorted(
|
||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
all_edges_data = truncate_list_by_token_size(
|
||||
all_edges_data,
|
||||
key=lambda x: x["description"],
|
||||
max_token_size=query_param.max_token_for_global_context,
|
||||
)
|
||||
return all_edges_data
|
||||
|
||||
async def global_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
) -> str:
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
result = await use_model_func(kw_prompt)
|
||||
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("high_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
# Handle parsing error
|
||||
print(f"JSON parsing error: {e}")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
context = await _build_global_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
if query_param.only_need_context:
|
||||
return context
|
||||
if context is None:
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
)
|
||||
response = await use_model_func(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
return response
|
||||
|
||||
async def _build_global_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
):
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
|
||||
if not len(results):
|
||||
return None
|
||||
|
||||
edge_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
)
|
||||
|
||||
if not all([n is not None for n in edge_datas]):
|
||||
logger.warning("Some edges are missing, maybe the storage is damaged")
|
||||
edge_degree = await asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
|
||||
)
|
||||
edge_datas = [
|
||||
{"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
|
||||
for k, v, d in zip(results, edge_datas, edge_degree)
|
||||
if v is not None
|
||||
]
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
edge_datas = truncate_list_by_token_size(
|
||||
edge_datas,
|
||||
key=lambda x: x["description"],
|
||||
max_token_size=query_param.max_token_for_global_context,
|
||||
)
|
||||
|
||||
use_entities = await _find_most_related_entities_from_relationships(
|
||||
edge_datas, query_param, knowledge_graph_inst
|
||||
)
|
||||
use_text_units = await _find_related_text_unit_from_relationships(
|
||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
)
|
||||
logger.info(
|
||||
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
|
||||
)
|
||||
relations_section_list = [
|
||||
["id", "source", "target", "description", "keywords", "weight", "rank"]
|
||||
]
|
||||
for i, e in enumerate(edge_datas):
|
||||
relations_section_list.append(
|
||||
[
|
||||
i,
|
||||
e["src_id"],
|
||||
e["tgt_id"],
|
||||
e["description"],
|
||||
e["keywords"],
|
||||
e["weight"],
|
||||
e["rank"],
|
||||
]
|
||||
)
|
||||
relations_context = list_of_list_to_csv(relations_section_list)
|
||||
|
||||
entites_section_list = [["id", "entity", "type", "description", "rank"]]
|
||||
for i, n in enumerate(use_entities):
|
||||
entites_section_list.append(
|
||||
[
|
||||
i,
|
||||
n["entity_name"],
|
||||
n.get("entity_type", "UNKNOWN"),
|
||||
n.get("description", "UNKNOWN"),
|
||||
n["rank"],
|
||||
]
|
||||
)
|
||||
entities_context = list_of_list_to_csv(entites_section_list)
|
||||
|
||||
text_units_section_list = [["id", "content"]]
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
|
||||
return f"""
|
||||
-----Entities-----
|
||||
```csv
|
||||
{entities_context}
|
||||
```
|
||||
-----Relationships-----
|
||||
```csv
|
||||
{relations_context}
|
||||
```
|
||||
-----Sources-----
|
||||
```csv
|
||||
{text_units_context}
|
||||
```
|
||||
"""
|
||||
|
||||
async def _find_most_related_entities_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
entity_names = set()
|
||||
for e in edge_datas:
|
||||
entity_names.add(e["src_id"])
|
||||
entity_names.add(e["tgt_id"])
|
||||
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
||||
)
|
||||
|
||||
node_degrees = await asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k, "rank": d}
|
||||
for k, n, d in zip(entity_names, node_datas, node_degrees)
|
||||
]
|
||||
|
||||
node_datas = truncate_list_by_token_size(
|
||||
node_datas,
|
||||
key=lambda x: x["description"],
|
||||
max_token_size=query_param.max_token_for_local_context,
|
||||
)
|
||||
|
||||
return node_datas
|
||||
|
||||
async def _find_related_text_unit_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
for dp in edge_datas
|
||||
]
|
||||
|
||||
all_text_units_lookup = {}
|
||||
|
||||
for index, unit_list in enumerate(text_units):
|
||||
for c_id in unit_list:
|
||||
if c_id not in all_text_units_lookup:
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": await text_chunks_db.get_by_id(c_id),
|
||||
"order": index,
|
||||
}
|
||||
|
||||
if any([v is None for v in all_text_units_lookup.values()]):
|
||||
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
||||
all_text_units = [
|
||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||
]
|
||||
all_text_units = sorted(
|
||||
all_text_units, key=lambda x: x["order"]
|
||||
)
|
||||
all_text_units = truncate_list_by_token_size(
|
||||
all_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
)
|
||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
||||
|
||||
return all_text_units
|
||||
|
||||
async def hybird_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
) -> str:
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
|
||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||
kw_prompt = kw_prompt_temp.format(query=query)
|
||||
result = await use_model_func(kw_prompt)
|
||||
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
hl_keywords = ', '.join(hl_keywords)
|
||||
ll_keywords = ', '.join(ll_keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
# Handle parsing error
|
||||
print(f"JSON parsing error: {e}")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
low_level_context = await _build_local_query_context(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
high_level_context = await _build_global_query_context(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
|
||||
context = combine_contexts(high_level_context, low_level_context)
|
||||
|
||||
if query_param.only_need_context:
|
||||
return context
|
||||
if context is None:
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
sys_prompt_temp = PROMPTS["rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
context_data=context, response_type=query_param.response_type
|
||||
)
|
||||
response = await use_model_func(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
return response
|
||||
|
||||
def combine_contexts(high_level_context, low_level_context):
|
||||
# Function to extract entities, relationships, and sources from context strings
|
||||
def extract_sections(context):
|
||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
|
||||
entities = entities_match.group(1) if entities_match else ''
|
||||
relationships = relationships_match.group(1) if relationships_match else ''
|
||||
sources = sources_match.group(1) if sources_match else ''
|
||||
|
||||
return entities, relationships, sources
|
||||
|
||||
# Extract sections from both contexts
|
||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||
|
||||
# Combine and deduplicate the entities
|
||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
||||
combined_entities = '\n'.join(combined_entities_set)
|
||||
|
||||
# Combine and deduplicate the relationships
|
||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
||||
combined_relationships = '\n'.join(combined_relationships_set)
|
||||
|
||||
# Combine and deduplicate the sources
|
||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
||||
combined_sources = '\n'.join(combined_sources_set)
|
||||
|
||||
# Format the combined context
|
||||
return f"""
|
||||
-----Entities-----
|
||||
```csv
|
||||
{combined_entities}
|
||||
-----Relationships-----
|
||||
{combined_relationships}
|
||||
-----Sources-----
|
||||
{combined_sources}
|
||||
"""
|
||||
|
||||
async def naive_query(
|
||||
query,
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
):
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
||||
if not len(results):
|
||||
return PROMPTS["fail_response"]
|
||||
chunks_ids = [r["id"] for r in results]
|
||||
chunks = await text_chunks_db.get_by_ids(chunks_ids)
|
||||
|
||||
maybe_trun_chunks = truncate_list_by_token_size(
|
||||
chunks,
|
||||
key=lambda x: x["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
)
|
||||
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
|
||||
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
|
||||
if query_param.only_need_context:
|
||||
return section
|
||||
sys_prompt_temp = PROMPTS["naive_rag_response"]
|
||||
sys_prompt = sys_prompt_temp.format(
|
||||
content_data=section, response_type=query_param.response_type
|
||||
)
|
||||
response = await use_model_func(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
return response
|
||||
|
Reference in New Issue
Block a user