fix: take global_config from storage class

This commit is contained in:
drahnreb
2025-04-17 16:57:53 +02:00
parent 0f949dd5d7
commit e71f466910
2 changed files with 11 additions and 26 deletions

View File

@@ -7,7 +7,7 @@ import warnings
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
from lightrag.kg import ( from lightrag.kg import (
STORAGES, STORAGES,

View File

@@ -116,7 +116,6 @@ async def _handle_entity_relation_summary(
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"] llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["summary_to_max_tokens"] summary_max_tokens = global_config["summary_to_max_tokens"]
language = global_config["addon_params"].get( language = global_config["addon_params"].get(
@@ -842,7 +841,6 @@ async def kg_query(
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
if query_param.only_need_context: if query_param.only_need_context:
@@ -1114,7 +1112,6 @@ async def mix_kg_vector_query(
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
return context return context
@@ -1267,7 +1264,6 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str],
): ):
logger.info(f"Process {os.getpid()} building query context...") logger.info(f"Process {os.getpid()} building query context...")
if query_param.mode == "local": if query_param.mode == "local":
@@ -1277,7 +1273,6 @@ async def _build_query_context(
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
elif query_param.mode == "global": elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data( entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1286,7 +1281,6 @@ async def _build_query_context(
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
else: # hybrid mode else: # hybrid mode
ll_data = await _get_node_data( ll_data = await _get_node_data(
@@ -1295,7 +1289,6 @@ async def _build_query_context(
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
hl_data = await _get_edge_data( hl_data = await _get_edge_data(
hl_keywords, hl_keywords,
@@ -1303,7 +1296,6 @@ async def _build_query_context(
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
( (
@@ -1350,7 +1342,6 @@ async def _get_node_data(
entities_vdb: BaseVectorStorage, entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str],
): ):
# get similar entities # get similar entities
logger.info( logger.info(
@@ -1387,13 +1378,13 @@ async def _get_node_data(
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entity text chunk # get entity text chunk
use_text_units = await _find_most_related_text_unit_from_entities( use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config node_datas, query_param, text_chunks_db, knowledge_graph_inst,
) )
use_relations = await _find_most_related_edges_from_entities( use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst, global_config node_datas, query_param, knowledge_graph_inst,
) )
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
len_node_datas = len(node_datas) len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size( node_datas = truncate_list_by_token_size(
node_datas, node_datas,
@@ -1493,7 +1484,6 @@ async def _find_most_related_text_unit_from_entities(
query_param: QueryParam, query_param: QueryParam,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
): ):
text_units = [ text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1575,7 +1565,7 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found") logger.warning("No valid text units found")
return [] return []
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
all_text_units = sorted( all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
) )
@@ -1598,7 +1588,6 @@ async def _find_most_related_edges_from_entities(
node_datas: list[dict], node_datas: list[dict],
query_param: QueryParam, query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
): ):
node_names = [dp["entity_name"] for dp in node_datas] node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1638,7 +1627,7 @@ async def _find_most_related_edges_from_entities(
} }
all_edges_data.append(combined) all_edges_data.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
all_edges_data = sorted( all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1662,7 +1651,6 @@ async def _get_edge_data(
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str],
): ):
logger.info( logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1703,7 +1691,7 @@ async def _get_edge_data(
} }
edge_datas.append(combined) edge_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
edge_datas = sorted( edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1715,10 +1703,10 @@ async def _get_edge_data(
) )
use_entities, use_text_units = await asyncio.gather( use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships( _find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst, global_config edge_datas, query_param, knowledge_graph_inst,
), ),
_find_related_text_unit_from_relationships( _find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
), ),
) )
logger.info( logger.info(
@@ -1798,7 +1786,6 @@ async def _find_most_related_entities_from_relationships(
edge_datas: list[dict], edge_datas: list[dict],
query_param: QueryParam, query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
): ):
entity_names = [] entity_names = []
seen = set() seen = set()
@@ -1829,7 +1816,7 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree} combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined) node_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
len_node_datas = len(node_datas) len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size( node_datas = truncate_list_by_token_size(
node_datas, node_datas,
@@ -1849,7 +1836,6 @@ async def _find_related_text_unit_from_relationships(
query_param: QueryParam, query_param: QueryParam,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
): ):
text_units = [ text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1891,7 +1877,7 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering") logger.warning("No valid text chunks after filtering")
return [] return []
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size( truncated_text_units = truncate_list_by_token_size(
valid_text_units, valid_text_units,
key=lambda x: x["data"]["content"], key=lambda x: x["data"]["content"],
@@ -2128,7 +2114,6 @@ async def kg_query_with_keywords(
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
global_config,
) )
if not context: if not context:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]