From c956e39621e9aa696827598afb2f9665eabe90a6 Mon Sep 17 00:00:00 2001
From: benx13
Date: Tue, 5 Nov 2024 18:36:59 -0800
Subject: [PATCH] bug fix issue #95

---
 lightrag/lightrag.py |  9 ++++----
 lightrag/operate.py  | 54 +++++++++++++++++++++++++++++---------------
 2 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5d271860..9cce75ba 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -55,13 +55,12 @@ from .base import (
 
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
-        loop = asyncio.get_running_loop()
+        return asyncio.get_event_loop()
     except RuntimeError:
         logger.info("Creating a new event loop in main thread.")
-        # loop = asyncio.new_event_loop()
-        # asyncio.set_event_loop(loop)
-        loop = asyncio.get_event_loop()
-    return loop
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        return loop
 
 
 @dataclass
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 2edeb548..23138ff9 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -561,46 +561,64 @@ async def _find_most_related_text_unit_from_entities(
         if not this_edges:
             continue
         all_one_hop_nodes.update([e[1] for e in this_edges])
+
     all_one_hop_nodes = list(all_one_hop_nodes)
     all_one_hop_nodes_data = await asyncio.gather(
         *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
     )
+
+    # Add null check for node data
     all_one_hop_text_units_lookup = {
         k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
         for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
-        if v is not None
+        if v is not None and "source_id" in v  # Add source_id check
     }
+
     all_text_units_lookup = {}
     for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
         for c_id in this_text_units:
             if c_id in all_text_units_lookup:
                 continue
             relation_counts = 0
-            for e in this_edges:
-                if (
-                    e[1] in all_one_hop_text_units_lookup
-                    and c_id in all_one_hop_text_units_lookup[e[1]]
-                ):
-                    relation_counts += 1
-            all_text_units_lookup[c_id] = {
-                "data": await text_chunks_db.get_by_id(c_id),
-                "order": index,
-                "relation_counts": relation_counts,
-            }
-    if any([v is None for v in all_text_units_lookup.values()]):
-        logger.warning("Text chunks are missing, maybe the storage is damaged")
+            if this_edges:  # Add check for None edges
+                for e in this_edges:
+                    if (
+                        e[1] in all_one_hop_text_units_lookup
+                        and c_id in all_one_hop_text_units_lookup[e[1]]
+                    ):
+                        relation_counts += 1
+
+            chunk_data = await text_chunks_db.get_by_id(c_id)
+            if chunk_data is not None and "content" in chunk_data:  # Add content check
+                all_text_units_lookup[c_id] = {
+                    "data": chunk_data,
+                    "order": index,
+                    "relation_counts": relation_counts,
+                }
+
+    # Filter out None values and ensure data has content
     all_text_units = [
-        {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
+        {"id": k, **v}
+        for k, v in all_text_units_lookup.items()
+        if v is not None and v.get("data") is not None and "content" in v["data"]
     ]
+
+    if not all_text_units:
+        logger.warning("No valid text units found")
+        return []
+
     all_text_units = sorted(
-        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
+        all_text_units,
+        key=lambda x: (x["order"], -x["relation_counts"])
     )
+
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
     )
-    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
+
+    all_text_units = [t["data"] for t in all_text_units]
     return all_text_units
 
 
@@ -1028,7 +1046,7 @@ def combine_contexts(high_level_context, low_level_context):
 -----Sources-----
 ```csv
 {combined_sources}
-``
+```
 """
 
 