add citation
@@ -138,6 +138,7 @@ async def _handle_entity_relation_summary(
 async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
+    file_path: str = "unknown_source",
 ):
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
@@ -171,13 +172,14 @@ async def _handle_single_entity_extraction(
         entity_type=entity_type,
         description=entity_description,
         source_id=chunk_key,
-        metadata={"created_at": time.time()},
+        metadata={"created_at": time.time(), "file_path": file_path},
     )
 
 
 async def _handle_single_relationship_extraction(
     record_attributes: list[str],
     chunk_key: str,
+    file_path: str = "unknown_source",
 ):
     if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
         return None
@@ -199,7 +201,7 @@ async def _handle_single_relationship_extraction(
         description=edge_description,
         keywords=edge_keywords,
         source_id=edge_source_id,
-        metadata={"created_at": time.time()},
+        metadata={"created_at": time.time(), "file_path": file_path},
     )
 
 
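
Note: both handlers now take a file_path argument (defaulting to "unknown_source") and record it in each extracted record's metadata next to created_at, which is what later makes citations possible. A minimal standalone sketch of the entity-side pattern; the record layout is inferred from the guard above, and the field cleaning the real module performs is omitted:

    import time

    def handle_entity_record(record_attributes, chunk_key, file_path="unknown_source"):
        # Same guard as in the diff: at least 4 fields, first field tags an entity.
        if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
            return None
        return dict(
            entity_name=record_attributes[1],
            entity_type=record_attributes[2],
            description=record_attributes[3],
            source_id=chunk_key,
            metadata={"created_at": time.time(), "file_path": file_path},
        )

    print(handle_entity_record(
        ['"entity"', "Marie Curie", "person", "physicist and chemist"],
        chunk_key="chunk-1",
        file_path="docs/curie.txt",
    ))
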
@@ -213,6 +215,7 @@ async def _merge_nodes_then_upsert(
     already_entity_types = []
     already_source_ids = []
     already_description = []
+    already_file_paths = []
 
     already_node = await knowledge_graph_inst.get_node(entity_name)
     if already_node is not None:
@@ -220,6 +223,9 @@ async def _merge_nodes_then_upsert(
         already_source_ids.extend(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
         )
+        already_file_paths.extend(
+            split_string_by_multi_markers(already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP])
+        )
         already_description.append(already_node["description"])
 
     entity_type = sorted(
@@ -235,6 +241,10 @@ async def _merge_nodes_then_upsert(
     source_id = GRAPH_FIELD_SEP.join(
         set([dp["source_id"] for dp in nodes_data] + already_source_ids)
     )
+    file_path = GRAPH_FIELD_SEP.join(
+        set([dp["metadata"]["file_path"] for dp in nodes_data] + already_file_paths)
+    )
+    print(f"file_path: {file_path}")
     description = await _handle_entity_relation_summary(
         entity_name, description, global_config
     )
@@ -243,6 +253,7 @@ async def _merge_nodes_then_upsert(
         entity_type=entity_type,
         description=description,
         source_id=source_id,
+        file_path=file_path,
     )
     await knowledge_graph_inst.upsert_node(
         entity_name,
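
Note: node merge now accumulates file paths exactly the way it accumulates source IDs: split the stored string on GRAPH_FIELD_SEP, union with the incoming fragments' paths, and re-join. (The print(f"file_path: ...") line added alongside reads like a leftover debug trace.) A standalone sketch of the round-trip, assuming the separator is the module's usual "<SEP>":

    GRAPH_FIELD_SEP = "<SEP>"  # assumed value of the module's separator constant

    def merge_file_paths(new_paths, already_joined=None):
        # Split the stored value into parts, union with the new ones, re-join.
        already = already_joined.split(GRAPH_FIELD_SEP) if already_joined else []
        return GRAPH_FIELD_SEP.join(set(new_paths + already))

    print(merge_file_paths(["a.txt", "b.txt"], "b.txt<SEP>c.txt"))
    # e.g. a.txt<SEP>c.txt<SEP>b.txt (three unique paths; set order is unspecified)
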
@@ -263,6 +274,7 @@ async def _merge_edges_then_upsert(
     already_source_ids = []
     already_description = []
     already_keywords = []
+    already_file_paths = []
 
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
@@ -278,6 +290,14 @@ async def _merge_edges_then_upsert(
                     already_edge["source_id"], [GRAPH_FIELD_SEP]
                 )
             )
 
+        # Get file_path with empty string default if missing or None
+        if already_edge.get("file_path") is not None:
+            already_file_paths.extend(
+                split_string_by_multi_markers(
+                    already_edge["metadata"]["file_path"], [GRAPH_FIELD_SEP]
+                )
+            )
+
         # Get description with empty string default if missing or None
         if already_edge.get("description") is not None:
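
Note: the guard here tests already_edge.get("file_path") at top level but then reads already_edge["metadata"]["file_path"], so an edge storing the path only under metadata is silently skipped, and one storing it only at top level raises a KeyError. A more defensive read covering both layouts, reusing the module's own split helper, might look like this (hypothetical, not part of the commit):

    fp = already_edge.get("file_path") or already_edge.get("metadata", {}).get("file_path")
    if fp:
        already_file_paths.extend(split_string_by_multi_markers(fp, [GRAPH_FIELD_SEP]))
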
@@ -315,6 +335,9 @@ async def _merge_edges_then_upsert(
             + already_source_ids
         )
     )
+    file_path = GRAPH_FIELD_SEP.join(
+        set([dp["metadata"]["file_path"] for dp in edges_data if dp.get("metadata", {}).get("file_path")] + already_file_paths)
+    )
 
     for need_insert_id in [src_id, tgt_id]:
         if not (await knowledge_graph_inst.has_node(need_insert_id)):
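
Note: on the edge side, the comprehension filters out fragments with no metadata.file_path so that missing paths never contribute empty entries to the joined string. A standalone sketch of the same filtered union (separator assumed to be "<SEP>" as above):

    GRAPH_FIELD_SEP = "<SEP>"

    def collect_edge_paths(edges_data, already_file_paths):
        # Only fragments that actually carry metadata.file_path contribute.
        fresh = [
            dp["metadata"]["file_path"]
            for dp in edges_data
            if dp.get("metadata", {}).get("file_path")
        ]
        return GRAPH_FIELD_SEP.join(set(fresh + already_file_paths))

    print(collect_edge_paths(
        [{"metadata": {"file_path": "a.txt"}}, {"metadata": {}}],
        ["b.txt"],
    ))  # a.txt<SEP>b.txt (order unspecified)
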
@@ -325,6 +348,7 @@ async def _merge_edges_then_upsert(
                     "source_id": source_id,
                     "description": description,
                     "entity_type": "UNKNOWN",
+                    "file_path": file_path,
                 },
             )
     description = await _handle_entity_relation_summary(
@@ -338,6 +362,7 @@ async def _merge_edges_then_upsert(
             description=description,
             keywords=keywords,
             source_id=source_id,
+            file_path=file_path,
         ),
     )
 
@@ -347,6 +372,7 @@ async def _merge_edges_then_upsert(
         description=description,
         keywords=keywords,
         source_id=source_id,
+        file_path=file_path,
     )
 
     return edge_data
@@ -456,11 +482,12 @@ async def extract_entities(
         else:
             return await use_llm_func(input_text)
 
-    async def _process_extraction_result(result: str, chunk_key: str):
+    async def _process_extraction_result(result: str, chunk_key: str, file_path: str = "unknown_source"):
         """Process a single extraction result (either initial or gleaning)
         Args:
             result (str): The extraction result to process
             chunk_key (str): The chunk key for source tracking
+            file_path (str): The file path for citation
         Returns:
             tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
         """
@@ -482,14 +509,14 @@ async def extract_entities(
             )
 
             if_entities = await _handle_single_entity_extraction(
-                record_attributes, chunk_key
+                record_attributes, chunk_key, file_path
             )
             if if_entities is not None:
                 maybe_nodes[if_entities["entity_name"]].append(if_entities)
                 continue
 
             if_relation = await _handle_single_relationship_extraction(
-                record_attributes, chunk_key
+                record_attributes, chunk_key, file_path
            )
             if if_relation is not None:
                 maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
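
Note: _process_extraction_result tries the entity handler first and falls through to the relationship handler, collecting results into dicts keyed by entity name and by (src, tgt) pair. A standalone sketch of that dispatch shape, with trivial stand-ins for the two handlers:

    from collections import defaultdict

    def handle_entity(rec, chunk_key, file_path):
        return {"entity_name": rec[1]} if rec[0] == '"entity"' else None

    def handle_relationship(rec, chunk_key, file_path):
        return {"src_id": rec[1], "tgt_id": rec[2]} if rec[0] == '"relationship"' else None

    def dispatch(records, chunk_key, file_path):
        maybe_nodes, maybe_edges = defaultdict(list), defaultdict(list)
        for rec in records:
            ent = handle_entity(rec, chunk_key, file_path)
            if ent is not None:
                maybe_nodes[ent["entity_name"]].append(ent)
                continue
            rel = handle_relationship(rec, chunk_key, file_path)
            if rel is not None:
                maybe_edges[(rel["src_id"], rel["tgt_id"])].append(rel)
        return dict(maybe_nodes), dict(maybe_edges)

    print(dispatch([['"entity"', "A"], ['"relationship"', "A", "B"]], "chunk-1", "doc.txt"))
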
@@ -508,6 +535,8 @@ async def extract_entities(
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
+        # Get file path from chunk data or use default
+        file_path = chunk_dp.get("file_path", "unknown_source")
 
         # Get initial extraction
         hint_prompt = entity_extract_prompt.format(
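
Note: this is where file_path enters the pipeline; it rides along on each text chunk and falls back to "unknown_source" for chunks ingested without one. A one-line illustration with a hypothetical chunk dict:

    chunk_dp = {"content": "some text", "full_doc_id": "doc-1"}  # hypothetical chunk, no file_path
    file_path = chunk_dp.get("file_path", "unknown_source")      # -> "unknown_source"
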
@@ -517,9 +546,9 @@ async def extract_entities(
         final_result = await _user_llm_func_with_cache(hint_prompt)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
 
-        # Process initial extraction
+        # Process initial extraction with file path
         maybe_nodes, maybe_edges = await _process_extraction_result(
-            final_result, chunk_key
+            final_result, chunk_key, file_path
         )
 
         # Process additional gleaning results
@@ -530,9 +559,9 @@ async def extract_entities(
 
             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
 
-            # Process gleaning result separately
+            # Process gleaning result separately with file path
             glean_nodes, glean_edges = await _process_extraction_result(
-                glean_result, chunk_key
+                glean_result, chunk_key, file_path
             )
 
             # Merge results
@@ -594,7 +623,7 @@ async def extract_entities(
             for k, v in maybe_edges.items()
         ]
     )
-
+
     if not (all_entities_data or all_relationships_data):
         log_message = "Didn't extract any entities and relationships."
         logger.info(log_message)
@@ -637,8 +666,10 @@ async def extract_entities(
                 "entity_type": dp["entity_type"],
                 "content": f"{dp['entity_name']}\n{dp['description']}",
                 "source_id": dp["source_id"],
+                "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
                 "metadata": {
-                    "created_at": dp.get("metadata", {}).get("created_at", time.time())
+                    "created_at": dp.get("metadata", {}).get("created_at", time.time()),
+                    "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
                 },
             }
             for dp in all_entities_data
@@ -653,8 +684,10 @@ async def extract_entities(
                 "keywords": dp["keywords"],
                 "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                 "source_id": dp["source_id"],
+                "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
                 "metadata": {
-                    "created_at": dp.get("metadata", {}).get("created_at", time.time())
+                    "created_at": dp.get("metadata", {}).get("created_at", time.time()),
+                    "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
                 },
             }
             for dp in all_relationships_data
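
Note: the vector-store payloads now carry file_path twice, once at top level and once inside metadata, presumably so it can be filtered on directly while consumers that read metadata keep working. The shape of one entity payload after this change (values illustrative):

    payload = {
        "entity_name": "Marie Curie",
        "entity_type": "person",
        "content": "Marie Curie\nphysicist and chemist",
        "source_id": "chunk-1",
        "file_path": "docs/curie.txt",      # top-level copy
        "metadata": {
            "created_at": 1700000000.0,
            "file_path": "docs/curie.txt",  # copy kept alongside created_at
        },
    }
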
@@ -1232,12 +1265,20 @@ async def _get_node_data(
             "description",
             "rank",
             "created_at",
+            "file_path",
         ]
     ]
     for i, n in enumerate(node_datas):
         created_at = n.get("created_at", "UNKNOWN")
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
+
+        # Get file path from metadata or directly from node data
+        file_path = n.get("file_path", "unknown_source")
+        if not file_path or file_path == "unknown_source":
+            # Try to get from metadata
+            file_path = n.get("metadata", {}).get("file_path", "unknown_source")
+
         entites_section_list.append(
             [
                 i,
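
Note: this two-step fallback (top-level file_path first, then metadata) is repeated in four places across _get_node_data and _get_edge_data; a small helper would keep the copies from drifting. A standalone sketch:

    def resolve_file_path(item):
        # Prefer the top-level field; fall back to metadata; default to the sentinel.
        fp = item.get("file_path", "unknown_source")
        if not fp or fp == "unknown_source":
            fp = item.get("metadata", {}).get("file_path", "unknown_source")
        return fp

    print(resolve_file_path({"metadata": {"file_path": "docs/a.txt"}}))  # docs/a.txt
    print(resolve_file_path({}))                                         # unknown_source
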
@@ -1246,6 +1287,7 @@ async def _get_node_data(
                 n.get("description", "UNKNOWN"),
                 n["rank"],
                 created_at,
+                file_path,
             ]
         )
     entities_context = list_of_list_to_csv(entites_section_list)
@@ -1260,6 +1302,7 @@ async def _get_node_data(
             "weight",
             "rank",
             "created_at",
+            "file_path",
         ]
     ]
     for i, e in enumerate(use_relations):
@@ -1267,6 +1310,13 @@ async def _get_node_data(
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
+
+        # Get file path from metadata or directly from edge data
+        file_path = e.get("file_path", "unknown_source")
+        if not file_path or file_path == "unknown_source":
+            # Try to get from metadata
+            file_path = e.get("metadata", {}).get("file_path", "unknown_source")
+
         relations_section_list.append(
             [
                 i,
@@ -1277,6 +1327,7 @@ async def _get_node_data(
                 e["weight"],
                 e["rank"],
                 created_at,
+                file_path,
             ]
         )
     relations_context = list_of_list_to_csv(relations_section_list)
@@ -1492,6 +1543,7 @@ async def _get_edge_data(
             "weight",
             "rank",
             "created_at",
+            "file_path",
         ]
     ]
     for i, e in enumerate(edge_datas):
@@ -1499,6 +1551,13 @@ async def _get_edge_data(
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
+
+        # Get file path from metadata or directly from edge data
+        file_path = e.get("file_path", "unknown_source")
+        if not file_path or file_path == "unknown_source":
+            # Try to get from metadata
+            file_path = e.get("metadata", {}).get("file_path", "unknown_source")
+
         relations_section_list.append(
             [
                 i,
@@ -1509,16 +1568,34 @@ async def _get_edge_data(
                 e["weight"],
                 e["rank"],
                 created_at,
+                file_path,
             ]
         )
     relations_context = list_of_list_to_csv(relations_section_list)
 
-    entites_section_list = [["id", "entity", "type", "description", "rank"]]
+    entites_section_list = [
+        [
+            "id",
+            "entity",
+            "type",
+            "description",
+            "rank",
+            "created_at",
+            "file_path"
+        ]
+    ]
     for i, n in enumerate(use_entities):
-        created_at = e.get("created_at", "Unknown")
+        created_at = n.get("created_at", "Unknown")
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
+
+        # Get file path from metadata or directly from node data
+        file_path = n.get("file_path", "unknown_source")
+        if not file_path or file_path == "unknown_source":
+            # Try to get from metadata
+            file_path = n.get("metadata", {}).get("file_path", "unknown_source")
+
         entites_section_list.append(
             [
                 i,
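
Note: besides widening the entity header row to include created_at and file_path, this hunk fixes a stale loop variable: created_at was previously read from e (left over from the relations loop above) rather than from n, the entity actually being iterated.
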
@@ -1527,6 +1604,7 @@ async def _get_edge_data(
                 n.get("description", "UNKNOWN"),
                 n["rank"],
                 created_at,
+                file_path,
             ]
         )
     entities_context = list_of_list_to_csv(entites_section_list)
@@ -1882,13 +1960,14 @@ async def kg_query_with_keywords(
     len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
     logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
 
+    # 6. Generate response
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
     )
 
-    # 清理响应内容
+    # Clean up response content
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
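
Note: the tail of the diff is unrelated cleanup in kg_query_with_keywords: a step-numbering comment is added and the Chinese comment (which says "clean up response content") is replaced with its English equivalent. The guarded block strips a prompt echo out of string responses; a standalone sketch of the idea (the rest of the .replace() chain is cut off above, so the exact calls are assumed):

    def clean_response(response, sys_prompt):
        # Some models echo the system prompt back; strip it when the response
        # is long enough to plausibly contain it.
        if isinstance(response, str) and len(response) > len(sys_prompt):
            response = response.replace(sys_prompt, "").strip()
        return response

    print(clean_response("SYS Hello", "SYS"))  # Hello
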