fix: take global_config from storage class

This commit is contained in:
drahnreb
2025-04-17 16:57:53 +02:00
parent 0f949dd5d7
commit e71f466910
2 changed files with 11 additions and 26 deletions

View File

@@ -116,7 +116,6 @@ async def _handle_entity_relation_summary(
use_llm_func: callable = global_config["llm_model_func"]
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["summary_to_max_tokens"]
language = global_config["addon_params"].get(
@@ -842,7 +841,6 @@ async def kg_query(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
if query_param.only_need_context:
@@ -1114,7 +1112,6 @@ async def mix_kg_vector_query(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
return context
@@ -1267,7 +1264,6 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
logger.info(f"Process {os.getpid()} buidling query context...")
if query_param.mode == "local":
@@ -1277,7 +1273,6 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
global_config,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1286,7 +1281,6 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
else: # hybrid mode
ll_data = await _get_node_data(
@@ -1295,7 +1289,6 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
global_config,
)
hl_data = await _get_edge_data(
hl_keywords,
@@ -1303,7 +1296,6 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
(
@@ -1350,7 +1342,6 @@ async def _get_node_data(
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
# get similar entities
logger.info(
@@ -1387,13 +1378,13 @@ async def _get_node_data(
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
node_datas, query_param, text_chunks_db, knowledge_graph_inst,
)
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst, global_config
node_datas, query_param, knowledge_graph_inst,
)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
@@ -1493,7 +1484,6 @@ async def _find_most_related_text_unit_from_entities(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1575,7 +1565,7 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
@@ -1598,7 +1588,6 @@ async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1638,7 +1627,7 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1662,7 +1651,6 @@ async def _get_edge_data(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1703,7 +1691,7 @@ async def _get_edge_data(
}
edge_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1715,10 +1703,10 @@ async def _get_edge_data(
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst, global_config
edge_datas, query_param, knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
),
)
logger.info(
@@ -1798,7 +1786,6 @@ async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
entity_names = []
seen = set()
@@ -1829,7 +1816,7 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
@@ -1849,7 +1836,6 @@ async def _find_related_text_unit_from_relationships(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1891,7 +1877,7 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
@@ -2128,7 +2114,6 @@ async def kg_query_with_keywords(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
if not context:
return PROMPTS["fail_response"]