bug fix issue #95

benx13 committed 2024-11-05 18:36:59 -08:00
parent b4ab9b26b8
commit 6f77f54c6d
2 changed files with 40 additions and 23 deletions


@@ -55,13 +55,12 @@ from .base import (
 
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
-        loop = asyncio.get_running_loop()
+        return asyncio.get_event_loop()
     except RuntimeError:
         logger.info("Creating a new event loop in main thread.")
-        # loop = asyncio.new_event_loop()
-        # asyncio.set_event_loop(loop)
-        loop = asyncio.get_event_loop()
-        return loop
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        return loop
 
 
 @dataclass
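Note on this hunk: as it reads, the removed `try` branch assigned the loop but never returned it, so callers received `None` and failed as soon as they touched the result. The fix returns the existing loop directly and only creates (and registers) a new one when `get_event_loop()` raises `RuntimeError`. A minimal self-contained sketch of the fixed helper; stdlib `logging` stands in for the module's `logger` (an assumption made for self-containment):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # Reuse the thread's current loop when one exists.
        return asyncio.get_event_loop()
    except RuntimeError:
        # No usable loop in this thread: create and register one.
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


async def _demo() -> str:
    return "ok"


if __name__ == "__main__":
    loop = always_get_an_event_loop()
    print(loop.run_until_complete(_demo()))  # prints: ok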
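```

In worker threads (and on newer Python versions outside a running loop) `get_event_loop()` raises `RuntimeError`, so the `except` branch is the path that fires there; the helper returns a usable loop either way.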


@@ -561,46 +561,64 @@ async def _find_most_related_text_unit_from_entities(
         if not this_edges:
             continue
         all_one_hop_nodes.update([e[1] for e in this_edges])
 
     all_one_hop_nodes = list(all_one_hop_nodes)
     all_one_hop_nodes_data = await asyncio.gather(
         *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
     )
+
+    # Add null check for node data
     all_one_hop_text_units_lookup = {
         k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
         for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
-        if v is not None
+        if v is not None and "source_id" in v  # Add source_id check
     }
+
     all_text_units_lookup = {}
     for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
         for c_id in this_text_units:
             if c_id in all_text_units_lookup:
                 continue
             relation_counts = 0
-            for e in this_edges:
-                if (
-                    e[1] in all_one_hop_text_units_lookup
-                    and c_id in all_one_hop_text_units_lookup[e[1]]
-                ):
-                    relation_counts += 1
-            all_text_units_lookup[c_id] = {
-                "data": await text_chunks_db.get_by_id(c_id),
-                "order": index,
-                "relation_counts": relation_counts,
-            }
-    if any([v is None for v in all_text_units_lookup.values()]):
-        logger.warning("Text chunks are missing, maybe the storage is damaged")
+            if this_edges:  # Add check for None edges
+                for e in this_edges:
+                    if (
+                        e[1] in all_one_hop_text_units_lookup
+                        and c_id in all_one_hop_text_units_lookup[e[1]]
+                    ):
+                        relation_counts += 1
+
+            chunk_data = await text_chunks_db.get_by_id(c_id)
+            if chunk_data is not None and "content" in chunk_data:  # Add content check
+                all_text_units_lookup[c_id] = {
+                    "data": chunk_data,
+                    "order": index,
+                    "relation_counts": relation_counts,
+                }
+
+    # Filter out None values and ensure data has content
     all_text_units = [
-        {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
+        {"id": k, **v}
+        for k, v in all_text_units_lookup.items()
+        if v is not None and v.get("data") is not None and "content" in v["data"]
     ]
+
+    if not all_text_units:
+        logger.warning("No valid text units found")
+        return []
+
     all_text_units = sorted(
-        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
+        all_text_units,
+        key=lambda x: (x["order"], -x["relation_counts"])
     )
+
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
     )
-    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
+
+    all_text_units = [t["data"] for t in all_text_units]
+
     return all_text_units
 
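This file's changes all apply one defensive pattern: treat every graph/KV lookup as fallible, and keep only records that exist and carry the fields downstream code dereferences (`source_id` on nodes, `content` on chunks). A runnable sketch of that pattern in isolation; `FakeKV`, `collect_valid_chunks`, and the sample records are hypothetical stand-ins for `text_chunks_db`, not the project's storage API:

```python
import asyncio
from typing import Optional


class FakeKV:
    """In-memory stand-in for text_chunks_db; get_by_id returns None for missing ids."""

    def __init__(self, records: dict) -> None:
        self._records = records

    async def get_by_id(self, key: str) -> Optional[dict]:
        return self._records.get(key)


async def collect_valid_chunks(db: FakeKV, chunk_ids: list) -> list:
    """Keep only chunks that exist and carry a 'content' field."""
    valid = []
    for c_id in chunk_ids:
        chunk = await db.get_by_id(c_id)
        # Mirrors the commit's guards: skip missing records and
        # records lacking the field later code dereferences.
        if chunk is not None and "content" in chunk:
            valid.append({"id": c_id, **chunk})
    return valid


if __name__ == "__main__":
    db = FakeKV({
        "chunk-1": {"content": "first passage"},
        "chunk-2": None,            # damaged / missing record
        "chunk-3": {"tokens": 42},  # record without "content"
    })
    kept = asyncio.run(collect_valid_chunks(db, ["chunk-1", "chunk-2", "chunk-3"]))
    print([c["id"] for c in kept])  # ['chunk-1']
```

The same guard appears three times in the hunk (node data, chunk data, final filter) because any of the stores can be missing entries when the index is partially built or, as the replaced warning put it, the storage is damaged.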
@@ -1028,7 +1046,7 @@ def combine_contexts(high_level_context, low_level_context):
 
 -----Sources-----
 ```csv
 {combined_sources}
-``
+```
 """
 