set kg by start param, defaults to networkx

This commit is contained in:
Ken Wiltshire
2024-11-01 08:47:52 -04:00
parent e4509327dd
commit e966a14418
7 changed files with 48 additions and 13 deletions

1
.gitignore vendored
View File

@@ -7,5 +7,4 @@ lightrag-dev/
dist/
env/
local_neo4jWorkDir/
local_neo4jWorkDir.bak/
neo4jWorkDir/

View File

@@ -1,5 +1,5 @@
print ("init package vars here. ......")
from .neo4j import GraphStorage as Neo4JStorage
# from .neo4j import GraphStorage as Neo4JStorage
# import sys

View File

@@ -25,7 +25,7 @@ from .storage import (
NetworkXStorage,
)
from .kg.neo4j import (
from .kg.neo4j_impl import (
GraphStorage as Neo4JStorage
)
@@ -58,10 +58,14 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@dataclass
class LightRAG:
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
kg: str = field(default="NetworkXStorage")
# text chunking
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
@@ -99,20 +103,15 @@ class LightRAG:
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
# module = importlib.import_module('kg.neo4j')
# Neo4JStorage = getattr(module, 'GraphStorage')
if True==True:
print ("using KG")
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
else:
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
enable_llm_cache: bool = True
# extension
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
# def get_configured_KG(self):
# return self.kg
def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
@@ -121,6 +120,10 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
#should move all storage setup here to leverage initial start params attached to self.
print (f"self.kg set to: {self.kg}")
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
@@ -169,6 +172,12 @@ class LightRAG:
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
"Neo4JStorage": Neo4JStorage,
"NetworkXStorage": NetworkXStorage,
# "new_kg_here": KGClass
}
def insert(self, string_or_strings):
loop = always_get_an_event_loop()

View File

@@ -71,6 +71,7 @@ async def _handle_entity_relation_summary(
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
print ("Summarized: {context_base} for entity relationship {} ")
return summary
@@ -78,6 +79,7 @@ async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
@@ -263,6 +265,7 @@ async def extract_entities(
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
nonlocal already_processed, already_entities, already_relations
print (f"kw: processing a single chunk, {chunk_key_dp}")
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
@@ -432,6 +435,7 @@ async def local_query(
text_chunks_db,
query_param,
)
print (f"got the following context {context} based on prompt keywords {keywords}")
if query_param.only_need_context:
return context
if context is None:
@@ -440,6 +444,7 @@ async def local_query(
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
print (f"local query:{query} local sysprompt:{sys_prompt}")
response = await use_model_func(
query,
system_prompt=sys_prompt,
@@ -465,14 +470,20 @@ async def _build_local_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
print ("kw1: ENTITIES VDB QUERY**********************************")
results = await entities_vdb.query(query, top_k=query_param.top_k)
print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")
if not len(results):
return None
print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
print ("kw4: getting node degrees next for the same entities/nodes")
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
@@ -480,7 +491,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
]
]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
@@ -718,6 +729,7 @@ async def _build_global_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
print ("RELATIONSHIPS VDB QUERY**********************************")
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
@@ -883,12 +895,14 @@ async def hybrid_query(
query_param: QueryParam,
global_config: dict,
) -> str:
print ("HYBRID QUERY *********")
low_level_context = None
high_level_context = None
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
print ( f"kw:kw_prompt: {kw_prompt}")
result = await use_model_func(kw_prompt)
try:
@@ -897,6 +911,8 @@ async def hybrid_query(
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
print (f"High level key words: {hl_keywords}")
print (f"Low level key words: {ll_keywords}")
except json.JSONDecodeError:
try:
result = (
@@ -926,6 +942,8 @@ async def hybrid_query(
query_param,
)
print (f"low_level_context: {low_level_context}")
if hl_keywords:
high_level_context = await _build_global_query_context(
hl_keywords,
@@ -935,6 +953,8 @@ async def hybrid_query(
text_chunks_db,
query_param,
)
print (f"high_level_context: {high_level_context}")
context = combine_contexts(high_level_context, low_level_context)
@@ -951,6 +971,7 @@ async def hybrid_query(
query,
system_prompt=sys_prompt,
)
print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -1044,10 +1065,13 @@ async def naive_query(
):
use_model_func = global_config["llm_model_func"]
results = await chunks_vdb.query(query, top_k=query_param.top_k)
print (f"raw chunks from chunks_vdb.query {results}")
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")
maybe_trun_chunks = truncate_list_by_token_size(
chunks,

View File

@@ -95,6 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
print (f"Upserting to vector: {list_data}")
results = self._client.upsert(datas=list_data)
return results
@@ -109,6 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
print (f"vector db results {results} for query {query}")
return results
async def index_done_callback(self):

View File

@@ -15,7 +15,8 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
kg="Neo4JStorage"
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
)