diff --git a/.gitignore b/.gitignore
index 7a161f3c..e04f6472 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,4 @@ lightrag-dev/
 dist/
 env/
 local_neo4jWorkDir/
-local_neo4jWorkDir.bak/
 neo4jWorkDir/
\ No newline at end of file
diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index ddd2bb79..f9e28648 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -1,5 +1,5 @@
 print ("init package vars here. ......")
-from .neo4j import GraphStorage as Neo4JStorage
+# from .neo4j import GraphStorage as Neo4JStorage
 
 
 # import sys
diff --git a/lightrag/kg/neo4j.py b/lightrag/kg/neo4j_impl.py
similarity index 100%
rename from lightrag/kg/neo4j.py
rename to lightrag/kg/neo4j_impl.py
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 774c7efe..af367fa0 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -25,7 +25,7 @@ from .storage import (
     NetworkXStorage,
 )
 
-from .kg.neo4j import (
+from .kg.neo4j_impl import (
     GraphStorage as Neo4JStorage
 )
 
@@ -58,10 +58,14 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
 @dataclass
 class LightRAG:
+
     working_dir: str = field(
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
+
+    kg: str = field(default="NetworkXStorage")
+
     # text chunking
     chunk_token_size: int = 1200
     chunk_overlap_token_size: int = 100
@@ -99,20 +103,15 @@ class LightRAG:
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
-
-    # module = importlib.import_module('kg.neo4j')
-    # Neo4JStorage = getattr(module, 'GraphStorage')
-    if True==True:
-        print ("using KG")
-        graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
-    else:
-        graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
     enable_llm_cache: bool = True
 
     # extension
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
+    # def get_configured_KG(self):
+    #     return self.kg
+
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
@@ -121,6 +120,10 @@ class LightRAG:
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
+        # should move all storage setup here to leverage initial start params attached to self.
+ print (f"self.kg set to: {self.kg}") + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg] + if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -169,6 +172,12 @@ class LightRAG: self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial(self.llm_model_func, hashing_kv=self.llm_response_cache) ) + def _get_storage_class(self) -> Type[BaseGraphStorage]: + return { + "Neo4JStorage": Neo4JStorage, + "NetworkXStorage": NetworkXStorage, + # "new_kg_here": KGClass + } def insert(self, string_or_strings): loop = always_get_an_event_loop() diff --git a/lightrag/operate.py b/lightrag/operate.py index a0729cd8..8a30f2ca 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -71,6 +71,7 @@ async def _handle_entity_relation_summary( use_prompt = prompt_template.format(**context_base) logger.debug(f"Trigger summary: {entity_or_relation_name}") summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens) + print ("Summarized: {context_base} for entity relationship {} ") return summary @@ -78,6 +79,7 @@ async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, ): + print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}") if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None # add this record as a node in the G @@ -263,6 +265,7 @@ async def extract_entities( async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): nonlocal already_processed, already_entities, already_relations + print (f"kw: processing a single chunk, {chunk_key_dp}") chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] @@ -432,6 +435,7 @@ async def local_query( text_chunks_db, query_param, ) + print (f"got the following context {context} based on prompt keywords {keywords}") if query_param.only_need_context: return context if context is None: @@ -440,6 +444,7 @@ async def local_query( sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) + print (f"local query:{query} local sysprompt:{sys_prompt}") response = await use_model_func( query, system_prompt=sys_prompt, @@ -465,14 +470,20 @@ async def _build_local_query_context( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, ): + print ("kw1: ENTITIES VDB QUERY**********************************") + results = await entities_vdb.query(query, top_k=query_param.top_k) + print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************") + if not len(results): return None + print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords") node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") + print ("kw4: getting node degrees next for the same entities/nodes") node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) @@ -480,7 +491,7 @@ async def _build_local_query_context( {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None - ] + ]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram. 
     use_text_units = await _find_most_related_text_unit_from_entities(
         node_datas, query_param, text_chunks_db, knowledge_graph_inst
     )
@@ -718,6 +729,7 @@ async def _build_global_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
+    print("RELATIONSHIPS VDB QUERY**********************************")
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
 
     if not len(results):
@@ -883,12 +895,14 @@ async def hybrid_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
+    print("HYBRID QUERY *********")
     low_level_context = None
     high_level_context = None
     use_model_func = global_config["llm_model_func"]
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
+    print(f"kw: kw_prompt: {kw_prompt}")
     result = await use_model_func(kw_prompt)
 
     try:
@@ -897,6 +911,8 @@ async def hybrid_query(
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
         ll_keywords = ", ".join(ll_keywords)
+        print(f"High level keywords: {hl_keywords}")
+        print(f"Low level keywords: {ll_keywords}")
     except json.JSONDecodeError:
         try:
             result = (
@@ -926,6 +942,8 @@ async def hybrid_query(
             query_param,
         )
 
+    print(f"low_level_context: {low_level_context}")
+
     if hl_keywords:
         high_level_context = await _build_global_query_context(
             hl_keywords,
@@ -935,6 +953,8 @@ async def hybrid_query(
             text_chunks_db,
             query_param,
         )
+    print(f"high_level_context: {high_level_context}")
+
 
     context = combine_contexts(high_level_context, low_level_context)
 
@@ -951,6 +971,7 @@ async def hybrid_query(
         query,
         system_prompt=sys_prompt,
     )
+    print(f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -1044,10 +1065,13 @@ async def naive_query(
 ):
     use_model_func = global_config["llm_model_func"]
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
+    print(f"raw chunks from chunks_vdb.query {results}")
     if not len(results):
         return PROMPTS["fail_response"]
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
+    print(f"raw chunks from text_chunks_db {chunks}, retrieved by id using the above chunk ids from the prev chunks_vdb query")
+
     maybe_trun_chunks = truncate_list_by_token_size(
         chunks,
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 6e14873b..caa453f7 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -95,6 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embeddings = np.concatenate(embeddings_list)
         for i, d in enumerate(list_data):
             d["__vector__"] = embeddings[i]
+        print(f"Upserting to vector: {list_data}")
         results = self._client.upsert(datas=list_data)
         return results
 
@@ -109,6 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         results = [
             {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
         ]
+        print(f"vector db results {results} for query {query}")
         return results
 
     async def index_done_callback(self):
diff --git a/testkg.py b/testkg.py
index 0c6e7d61..2131f840 100644
--- a/testkg.py
+++ b/testkg.py
@@ -15,7 +15,8 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
+    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+    kg="Neo4JStorage"
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
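
Reviewer note: the diff above replaces a hard-coded `if True==True` backend choice with a `kg` field and a name-to-class registry resolved once in `__post_init__`. The following is a minimal, runnable sketch of that selection pattern, not the actual implementation: the storage classes here are stubs standing in for the real ones (`lightrag.kg.neo4j_impl.GraphStorage`, `lightrag.storage.NetworkXStorage`), and only the registry and lookup logic mirror the diff.

    from dataclasses import dataclass, field
    from typing import Type


    class BaseGraphStorage:
        """Stub standing in for lightrag.base.BaseGraphStorage."""


    class NetworkXStorage(BaseGraphStorage):
        """Stub for the default NetworkX-backed graph storage."""


    class Neo4JStorage(BaseGraphStorage):
        """Stub for lightrag.kg.neo4j_impl.GraphStorage."""


    @dataclass
    class LightRAG:
        # `kg` names the graph backend; NetworkXStorage stays the default,
        # so existing callers are unaffected.
        kg: str = field(default="NetworkXStorage")

        def __post_init__(self):
            # Resolve the configured backend through the registry. An unknown
            # name raises KeyError here, at construction time, not mid-query.
            self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]

        def _get_storage_class(self) -> dict[str, Type[BaseGraphStorage]]:
            # One registry entry per backend; adding a backend is a one-line change.
            return {
                "Neo4JStorage": Neo4JStorage,
                "NetworkXStorage": NetworkXStorage,
            }


    # Mirrors the testkg.py change: opt in to the Neo4j backend by name.
    rag = LightRAG(kg="Neo4JStorage")
    assert rag.graph_storage_cls is Neo4JStorage

The design point is that the backend becomes a constructor argument instead of an edit to class-level code: an unsupported `kg` value fails fast with a `KeyError`, and the `# "new_kg_here": KGClass` placeholder in the diff marks where a new backend would be registered.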