fix: take global_config from storage class

This commit is contained in:
drahnreb
2025-04-17 16:57:53 +02:00
parent 0f949dd5d7
commit e71f466910
2 changed files with 11 additions and 26 deletions

View File

@@ -7,7 +7,7 @@ import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
from lightrag.kg import (
STORAGES,

View File

@@ -116,7 +116,6 @@ async def _handle_entity_relation_summary(
use_llm_func: callable = global_config["llm_model_func"]
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["summary_to_max_tokens"]
language = global_config["addon_params"].get(
@@ -842,7 +841,6 @@ async def kg_query(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
if query_param.only_need_context:
@@ -1114,7 +1112,6 @@ async def mix_kg_vector_query(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
return context
@@ -1267,7 +1264,6 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
logger.info(f"Process {os.getpid()} buidling query context...")
if query_param.mode == "local":
@@ -1277,7 +1273,6 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
global_config,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1286,7 +1281,6 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
else: # hybrid mode
ll_data = await _get_node_data(
@@ -1295,7 +1289,6 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
global_config,
)
hl_data = await _get_edge_data(
hl_keywords,
@@ -1303,7 +1296,6 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
(
@@ -1350,7 +1342,6 @@ async def _get_node_data(
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
# get similar entities
logger.info(
@@ -1387,13 +1378,13 @@ async def _get_node_data(
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entity text chunks
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
node_datas, query_param, text_chunks_db, knowledge_graph_inst,
)
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst, global_config
node_datas, query_param, knowledge_graph_inst,
)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
@@ -1493,7 +1484,6 @@ async def _find_most_related_text_unit_from_entities(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1575,7 +1565,7 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
@@ -1598,7 +1588,6 @@ async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1638,7 +1627,7 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1662,7 +1651,6 @@ async def _get_edge_data(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1703,7 +1691,7 @@ async def _get_edge_data(
}
edge_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1715,10 +1703,10 @@ async def _get_edge_data(
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst, global_config
edge_datas, query_param, knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
),
)
logger.info(
@@ -1798,7 +1786,6 @@ async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
entity_names = []
seen = set()
@@ -1829,7 +1816,7 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
@@ -1849,7 +1836,6 @@ async def _find_related_text_unit_from_relationships(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1891,7 +1877,7 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
tokenizer: Tokenizer = global_config["tokenizer"]
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
@@ -2128,7 +2114,6 @@ async def kg_query_with_keywords(
relationships_vdb,
text_chunks_db,
query_param,
global_config,
)
if not context:
return PROMPTS["fail_response"]