From f7b66d2c222816676dd7882be7846eedb91d585d Mon Sep 17 00:00:00 2001 From: Dmytro Til Date: Fri, 24 Jan 2025 16:06:04 +0100 Subject: [PATCH] asyncio optimizations --- lightrag/operate.py | 173 +++++++++++++++++++++++++++----------------- 1 file changed, 106 insertions(+), 67 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index e1406904..94409c50 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -941,28 +941,35 @@ async def _build_query_context( query_param, ) else: # hybrid mode + ll_data, hl_data = await asyncio.gather( + _get_node_data( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ), + _get_edge_data( + hl_keywords, + knowledge_graph_inst, + relationships_vdb, + text_chunks_db, + query_param, + ), + ) + ( ll_entities_context, ll_relations_context, ll_text_units_context, - ) = await _get_node_data( - ll_keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) + ) = ll_data + ( hl_entities_context, hl_relations_context, hl_text_units_context, - ) = await _get_edge_data( - hl_keywords, - knowledge_graph_inst, - relationships_vdb, - text_chunks_db, - query_param, - ) + ) = hl_data + entities_context, relations_context, text_units_context = combine_contexts( [hl_entities_context, ll_entities_context], [hl_relations_context, ll_relations_context], @@ -996,28 +1003,31 @@ async def _get_node_data( if not len(results): return "", "", "" # get entity information - node_datas = await asyncio.gather( - *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] + node_datas, node_degrees = await asyncio.gather( + asyncio.gather( + *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] + ), + asyncio.gather( + *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] + ), ) + if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") - # get entity degree - node_degrees = await asyncio.gather( - *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] - ) node_datas = [ {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. # get entitytext chunk - use_text_units = await _find_most_related_text_unit_from_entities( - node_datas, query_param, text_chunks_db, knowledge_graph_inst - ) - # get relate edges - use_relations = await _find_most_related_edges_from_entities( - node_datas, query_param, knowledge_graph_inst + use_text_units, use_relations = await asyncio.gather( + _find_most_related_text_unit_from_entities( + node_datas, query_param, text_chunks_db, knowledge_graph_inst + ), + _find_most_related_edges_from_entities( + node_datas, query_param, knowledge_graph_inst + ), ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" @@ -1107,22 +1117,30 @@ async def _find_most_related_text_unit_from_entities( } all_text_units_lookup = {} + tasks = [] for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id not in all_text_units_lookup: - all_text_units_lookup[c_id] = { - "data": await text_chunks_db.get_by_id(c_id), - "order": index, - "relation_counts": 0, - } + tasks.append((c_id, index, this_edges)) - if this_edges: - for e in this_edges: - if ( - e[1] in all_one_hop_text_units_lookup - and c_id in all_one_hop_text_units_lookup[e[1]] - ): - all_text_units_lookup[c_id]["relation_counts"] += 1 + results = await asyncio.gather( + *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks] + ) + + for (c_id, index, this_edges), data in zip(tasks, results): + all_text_units_lookup[c_id] = { + "data": data, + "order": index, + "relation_counts": 0, + } + + if this_edges: + for e in this_edges: + if ( + e[1] in all_one_hop_text_units_lookup + and c_id in all_one_hop_text_units_lookup[e[1]] + ): + all_text_units_lookup[c_id]["relation_counts"] += 1 # Filter out None values and ensure data has content all_text_units = [ @@ -1167,11 +1185,11 @@ async def _find_most_related_edges_from_entities( seen.add(sorted_edge) all_edges.append(sorted_edge) - all_edges_pack = await asyncio.gather( - *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges] - ) - all_edges_degree = await asyncio.gather( - *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] + all_edges_pack, all_edges_degree = await asyncio.gather( + asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]), + asyncio.gather( + *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] + ), ) all_edges_data = [ {"src_tgt": k, "rank": d, **v} @@ -1201,15 +1219,21 @@ async def _get_edge_data( if not len(results): return "", "", "" - edge_datas = await asyncio.gather( - *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] + edge_datas, edge_degree = await asyncio.gather( + asyncio.gather( + *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] + ), + asyncio.gather( + *[ + knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) + for r in results + ] + ), ) if not all([n is not None for n in edge_datas]): logger.warning("Some edges are missing, maybe the storage is damaged") - edge_degree = await asyncio.gather( - *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results] - ) + edge_datas = [ { "src_id": k["src_id"], @@ -1230,11 +1254,13 @@ async def _get_edge_data( max_token_size=query_param.max_token_for_global_context, ) - use_entities = await _find_most_related_entities_from_relationships( - edge_datas, query_param, knowledge_graph_inst - ) - use_text_units = await _find_related_text_unit_from_relationships( - edge_datas, query_param, text_chunks_db, knowledge_graph_inst + use_entities, use_text_units = await asyncio.gather( + _find_most_related_entities_from_relationships( + edge_datas, query_param, knowledge_graph_inst + ), + _find_related_text_unit_from_relationships( + edge_datas, query_param, text_chunks_db, knowledge_graph_inst + ), ) logger.info( f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units" @@ -1307,12 +1333,19 @@ async def _find_most_related_entities_from_relationships( entity_names.append(e["tgt_id"]) seen.add(e["tgt_id"]) - node_datas = await asyncio.gather( - *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] - ) - - node_degrees = await asyncio.gather( - *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names] + node_datas, node_degrees = await asyncio.gather( + asyncio.gather( + *[ + knowledge_graph_inst.get_node(entity_name) + for entity_name in entity_names + ] + ), + asyncio.gather( + *[ + knowledge_graph_inst.node_degree(entity_name) + for entity_name in entity_names + ] + ), ) node_datas = [ {**n, "entity_name": k, "rank": d} @@ -1340,16 +1373,22 @@ async def _find_related_text_unit_from_relationships( ] all_text_units_lookup = {} + async def fetch_chunk_data(c_id, index): + if c_id not in all_text_units_lookup: + chunk_data = await text_chunks_db.get_by_id(c_id) + # Only store valid data + if chunk_data is not None and "content" in chunk_data: + all_text_units_lookup[c_id] = { + "data": chunk_data, + "order": index, + } + + tasks = [] for index, unit_list in enumerate(text_units): for c_id in unit_list: - if c_id not in all_text_units_lookup: - chunk_data = await text_chunks_db.get_by_id(c_id) - # Only store valid data - if chunk_data is not None and "content" in chunk_data: - all_text_units_lookup[c_id] = { - "data": chunk_data, - "order": index, - } + tasks.append(fetch_chunk_data(c_id, index)) + + await asyncio.gather(*tasks) if not all_text_units_lookup: logger.warning("No valid text chunks found")