fix format

This commit is contained in:
zrguo
2025-03-01 17:45:06 +08:00
parent 5bbe61a02d
commit 4219454fab
2 changed files with 56 additions and 35 deletions

View File

@@ -280,10 +280,7 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
""".format( """
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()

View File

@@ -141,17 +141,18 @@ async def _handle_single_entity_extraction(
if len(record_attributes) < 4 or record_attributes[0] != '"entity"': if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None return None
# add this record as a node in the G # add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper()) entity_name = clean_str(record_attributes[1]).strip('"')
if not entity_name.strip(): if not entity_name.strip():
return None return None
entity_type = clean_str(record_attributes[2].upper()) entity_type = clean_str(record_attributes[2]).strip('"')
entity_description = clean_str(record_attributes[3]) entity_description = clean_str(record_attributes[3]).strip('"')
entity_source_id = chunk_key entity_source_id = chunk_key
return dict( return dict(
entity_name=entity_name, entity_name=entity_name,
entity_type=entity_type, entity_type=entity_type,
description=entity_description, description=entity_description,
source_id=entity_source_id, source_id=entity_source_id,
metadata={"created_at": time.time()},
) )
@@ -162,14 +163,15 @@ async def _handle_single_relationship_extraction(
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None return None
# add this record as edge # add this record as edge
source = clean_str(record_attributes[1].upper()) source = clean_str(record_attributes[1]).strip('"')
target = clean_str(record_attributes[2].upper()) target = clean_str(record_attributes[2]).strip('"')
edge_description = clean_str(record_attributes[3]) edge_description = clean_str(record_attributes[3]).strip('"')
edge_keywords = clean_str(record_attributes[4]).strip('"')
edge_keywords = clean_str(record_attributes[4])
edge_source_id = chunk_key edge_source_id = chunk_key
weight = ( weight = (
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 float(record_attributes[-1].strip('"'))
if is_float_regex(record_attributes[-1])
else 1.0
) )
return dict( return dict(
src_id=source, src_id=source,
@@ -547,9 +549,13 @@ async def extract_entities(
if entity_vdb is not None: if entity_vdb is not None:
data_for_vdb = { data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): { compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"], "entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"], "source_id": dp["source_id"],
"metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time())
},
} }
for dp in all_entities_data for dp in all_entities_data
} }
@@ -560,11 +566,9 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"], "src_id": dp["src_id"],
"tgt_id": dp["tgt_id"], "tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"], "source_id": dp["source_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
"metadata": { "metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time()) "created_at": dp.get("metadata", {}).get("created_at", time.time())
}, },
@@ -960,7 +964,7 @@ async def mix_kg_vector_query(
stream=query_param.stream, stream=query_param.stream,
) )
# 清理响应内容 # Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt): if isinstance(response, str) and len(response) > len(sys_prompt):
response = ( response = (
response.replace(sys_prompt, "") response.replace(sys_prompt, "")
@@ -972,7 +976,7 @@ async def mix_kg_vector_query(
.strip() .strip()
) )
# 7. Save cache - 只有在收集完整响应后才缓存 # 7. Save cache - Only cache after collecting complete response
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
@@ -1128,8 +1132,19 @@ async def _get_node_data(
) )
# build prompt # build prompt
entites_section_list = [["id", "entity", "type", "description", "rank"]] entites_section_list = [
[
"id",
"entity",
"type",
"description",
"rank" "created_at",
]
]
for i, n in enumerate(node_datas): for i, n in enumerate(node_datas):
created_at = n.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
entites_section_list.append( entites_section_list.append(
[ [
i, i,
@@ -1137,6 +1152,7 @@ async def _get_node_data(
n.get("entity_type", "UNKNOWN"), n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"), n.get("description", "UNKNOWN"),
n["rank"], n["rank"],
created_at,
] ]
) )
entities_context = list_of_list_to_csv(entites_section_list) entities_context = list_of_list_to_csv(entites_section_list)
@@ -1401,6 +1417,10 @@ async def _get_edge_data(
entites_section_list = [["id", "entity", "type", "description", "rank"]] entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(use_entities): for i, n in enumerate(use_entities):
created_at = e.get("created_at", "Unknown")
# Convert timestamp to readable format
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
entites_section_list.append( entites_section_list.append(
[ [
i, i,
@@ -1408,6 +1428,7 @@ async def _get_edge_data(
n.get("entity_type", "UNKNOWN"), n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"), n.get("description", "UNKNOWN"),
n["rank"], n["rank"],
created_at,
] ]
) )
entities_context = list_of_list_to_csv(entites_section_list) entities_context = list_of_list_to_csv(entites_section_list)
@@ -1766,6 +1787,8 @@ async def kg_query_with_keywords(
system_prompt=sys_prompt, system_prompt=sys_prompt,
stream=query_param.stream, stream=query_param.stream,
) )
# Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt): if isinstance(response, str) and len(response) > len(sys_prompt):
response = ( response = (
response.replace(sys_prompt, "") response.replace(sys_prompt, "")
@@ -1777,7 +1800,7 @@ async def kg_query_with_keywords(
.strip() .strip()
) )
# Save to cache # 7. Save cache - Only cache after collecting complete response
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
@@ -1791,4 +1814,5 @@ async def kg_query_with_keywords(
cache_type="query", cache_type="query",
), ),
) )
return response return response