fix: take global_config from storage class
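The query-path helpers in operate.py no longer thread a global_config dict through every call: each storage object handed to them already carries the shared configuration, so settings such as the tokenizer are read off the storage instance instead. A minimal sketch of the pattern, assuming the dataclass shape of the storage base classes (field names inferred from this diff; the rest is illustrative, not verbatim from the repo):

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class BaseKVStorage:
        # Every storage instance is constructed with the shared LightRAG
        # config, so helpers can read settings off whatever storage they get.
        namespace: str
        global_config: dict[str, Any] = field(default_factory=dict)

    async def _find_related_text_unit_from_relationships(
        edge_datas, query_param, text_chunks_db, knowledge_graph_inst
    ):
        # After this commit the tokenizer comes from the storage instance,
        # not from an explicit global_config parameter.
        tokenizer = text_chunks_db.global_config.get("tokenizer")
        ...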
@@ -7,7 +7,7 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
 
 from lightrag.kg import (
     STORAGES,
@@ -116,7 +116,6 @@ async def _handle_entity_relation_summary(
     use_llm_func: callable = global_config["llm_model_func"]
     tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
-    tiktoken_model_name = global_config["tiktoken_model_name"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
 
     language = global_config["addon_params"].get(
@@ -842,7 +841,6 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
-        global_config,
     )
 
     if query_param.only_need_context:
@@ -1114,7 +1112,6 @@ async def mix_kg_vector_query(
             relationships_vdb,
             text_chunks_db,
             query_param,
-            global_config,
         )
 
         return context
@@ -1267,7 +1264,6 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict[str, str],
 ):
     logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
@@ -1277,7 +1273,6 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
-            global_config,
         )
     elif query_param.mode == "global":
         entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1286,7 +1281,6 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
-            global_config,
         )
     else:  # hybrid mode
         ll_data = await _get_node_data(
@@ -1295,7 +1289,6 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
-            global_config,
         )
         hl_data = await _get_edge_data(
             hl_keywords,
@@ -1303,7 +1296,6 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
-            global_config,
         )
 
     (
@@ -1350,7 +1342,6 @@ async def _get_node_data(
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict[str, str],
 ):
     # get similar entities
     logger.info(
@@ -1387,13 +1378,13 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing.  dont remember it in airvx.  check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst,
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst, global_config
+        node_datas, query_param, knowledge_graph_inst,
    )
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
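Note the lookup change in the hunk above: global_config["tokenizer"] raised KeyError the moment the key was missing, whereas text_chunks_db.global_config.get("tokenizer") returns None and defers the failure to the tokenizer's first use. If fail-fast behavior is ever wanted back, a small guard would restore it (the helper below is hypothetical, not part of this commit):

    def _require_tokenizer(storage):
        # Hypothetical fail-fast lookup: report a missing tokenizer at the
        # call site instead of as a later AttributeError on None.
        tokenizer = storage.global_config.get("tokenizer")
        if tokenizer is None:
            raise ValueError(
                f"tokenizer missing from {type(storage).__name__}.global_config"
            )
        return tokenizer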
@@ -1493,7 +1484,6 @@ async def _find_most_related_text_unit_from_entities(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1575,7 +1565,7 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )
@@ -1598,7 +1588,6 @@ async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict[str, str],
 ):
     node_names = [dp["entity_name"] for dp in node_datas]
     batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1638,7 +1627,7 @@ async def _find_most_related_edges_from_entities(
         }
         all_edges_data.append(combined)
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
     all_edges_data = sorted(
         all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1662,7 +1651,6 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict[str, str],
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1703,7 +1691,7 @@ async def _get_edge_data(
         }
         edge_datas.append(combined)
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1715,10 +1703,10 @@ async def _get_edge_data(
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst, global_config
+            edge_datas, query_param, knowledge_graph_inst,
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
         ),
     )
     logger.info(
@@ -1798,7 +1786,6 @@ async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict[str, str],
 ):
     entity_names = []
     seen = set()
@@ -1829,7 +1816,7 @@ async def _find_most_related_entities_from_relationships(
         combined = {**node, "entity_name": entity_name, "rank": degree}
         node_datas.append(combined)
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
@@ -1849,7 +1836,6 @@ async def _find_related_text_unit_from_relationships(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1891,7 +1877,7 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
+    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
     truncated_text_units = truncate_list_by_token_size(
         valid_text_units,
         key=lambda x: x["data"]["content"],
@@ -2128,7 +2114,6 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
-        global_config,
     )
     if not context:
         return PROMPTS["fail_response"]