Merge branch 'main' of https://github.com/jin38324/LightRAG
This commit is contained in:
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_text_units_lookup = {}
|
||||
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
||||
for c_id in this_text_units:
|
||||
if c_id in all_text_units_lookup:
|
||||
continue
|
||||
relation_counts = 0
|
||||
if this_edges: # Add check for None edges
|
||||
if c_id not in all_text_units_lookup:
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": await text_chunks_db.get_by_id(c_id),
|
||||
"order": index,
|
||||
"relation_counts": 0,
|
||||
}
|
||||
|
||||
if this_edges:
|
||||
for e in this_edges:
|
||||
if (
|
||||
e[1] in all_one_hop_text_units_lookup
|
||||
and c_id in all_one_hop_text_units_lookup[e[1]]
|
||||
):
|
||||
relation_counts += 1
|
||||
|
||||
chunk_data = await text_chunks_db.get_by_id(c_id)
|
||||
if chunk_data is not None and "content" in chunk_data: # Add content check
|
||||
all_text_units_lookup[c_id] = {
|
||||
"data": chunk_data,
|
||||
"order": index,
|
||||
"relation_counts": relation_counts,
|
||||
}
|
||||
all_text_units_lookup[c_id]["relation_counts"] += 1
|
||||
|
||||
# Filter out None values and ensure data has content
|
||||
all_text_units = [
|
||||
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
|
||||
all_related_edges = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
||||
)
|
||||
all_edges = set()
|
||||
all_edges = []
|
||||
seen = set()
|
||||
|
||||
for this_edges in all_related_edges:
|
||||
all_edges.update([tuple(sorted(e)) for e in this_edges])
|
||||
all_edges = list(all_edges)
|
||||
for e in this_edges:
|
||||
sorted_edge = tuple(sorted(e))
|
||||
if sorted_edge not in seen:
|
||||
seen.add(sorted_edge)
|
||||
all_edges.append(sorted_edge)
|
||||
|
||||
all_edges_pack = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
||||
)
|
||||
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
|
||||
query_param: QueryParam,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
entity_names = set()
|
||||
entity_names = []
|
||||
seen = set()
|
||||
|
||||
for e in edge_datas:
|
||||
entity_names.add(e["src_id"])
|
||||
entity_names.add(e["tgt_id"])
|
||||
if e["src_id"] not in seen:
|
||||
entity_names.append(e["src_id"])
|
||||
seen.add(e["src_id"])
|
||||
if e["tgt_id"] not in seen:
|
||||
entity_names.append(e["tgt_id"])
|
||||
seen.add(e["tgt_id"])
|
||||
|
||||
node_datas = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
||||
|
Reference in New Issue
Block a user