diff --git a/.gitignore b/.gitignore
index e04f6472..39fa6515 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ lightrag-dev/
 dist/
 env/
 local_neo4jWorkDir/
-neo4jWorkDir/
\ No newline at end of file
+neo4jWorkDir/
+ignore_this.txt
\ No newline at end of file
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index ddb78efe..6db885d6 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import Any, Union, cast
 import numpy as np
 import inspect
-# import package.common.utils as utils
 from lightrag.utils import load_json, logger, write_json
 from ..base import (
     BaseGraphStorage
 )
@@ -22,27 +21,6 @@ from tenacity import (
 
 
 
-
-
-# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
-# during indexing.
-
-
-
-
-# Replace with your actual URI, username, and password
-#local
-URI = "neo4j://localhost:7687"
-USERNAME = "neo4j"
-PASSWORD = "password"
-
-#aura
-# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
-# USERNAME = "neo4j"
-# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
-# Create a driver object
-
-
 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):
     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
+        credential_parts = ["URI", "USERNAME", "PASSWORD"]
+        credentials_set = all(x in os.environ for x in credential_parts)
+        if credentials_set:
+            URI = os.environ["URI"]
+            USERNAME = os.environ["USERNAME"]
+            PASSWORD = os.environ["PASSWORD"]
+        else:
+            raise Exception(f"One or more Neo4j credentials, {credential_parts}, not found in the environment")
+
         self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
             query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
             result = tx.run(query)
             single_result = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
             )
 
@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
             # if result.single() == None:
             #     print (f"this should not happen: ---- {label1}/{label2} {query}")
 
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
             )
 
@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
             result = session.run(query)
             for record in result:
                 result = record["n"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 edge_count = record["totalEdgeCount"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
                 )
                 return edge_count
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
                 RETURN count(r) AS degree"""
             result = session.run(query)
             record = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
             )
             return record["degree"]
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 result = dict(record["edge_properties"])
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
     #             if source_label and target_label:
    #                 connections.append((source_label, target_label))
 
-    #     logger.info(
+    #     logger.debug(
     #         f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
     #     )
     #     return connections
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
             result = tx.run(query, properties=properties)
             record = result.single()
             if record:
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
                 )
                 return dict(record["n"])
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
             """
             result = tx.run(query, properties=edge_properties)
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
             )
             return result.single()
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             #Define the Cypher query
             options = self.global_config["node2vec_params"]
-            logger.info(f"building embeddings with options {options}")
+            logger.debug(f"building embeddings with options {options}")
             query = f"""CALL gds.node2vec.write('91fbae6c', {
                 options
             })
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index af367fa0..d9b12a99 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -28,6 +28,13 @@ from .storage import (
 from .kg.neo4j_impl import (
     GraphStorage as Neo4JStorage
 )
+# future KG integrations
+
+# from .kg.ArangoDB_impl import (
+#     GraphStorage as ArangoDBStorage
+# )
+
+
 from .utils import (
     EmbeddingFunc,
@@ -64,7 +71,11 @@ class LightRAG:
     )
 
     kg: str = field(default="NetworkXStorage")
-
+
+    current_log_level = logger.level
+    log_level: str = field(default=current_log_level)
+
+
     # text chunking
     chunk_token_size: int = 1200
@@ -115,13 +126,14 @@ class LightRAG:
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
+        logger.setLevel(self.log_level)
+        logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
 
         #should move all storage setup here to leverage initial start params attached to self.
-        print (f"self.kg set to: {self.kg}")
         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
 
         if not os.path.exists(self.working_dir):
@@ -176,7 +188,7 @@ class LightRAG:
         return {
             "Neo4JStorage": Neo4JStorage,
             "NetworkXStorage": NetworkXStorage,
-            # "new_kg_here": KGClass
+            # "ArangoDBStorage": ArangoDBStorage
         }
 
     def insert(self, string_or_strings):
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 8a30f2ca..ebec5c3f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -71,7 +71,6 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
     summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
-    print ("Summarized: {context_base} for entity relationship {} ")
     return summary
 
 
@@ -79,7 +78,6 @@ async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
-    print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
     # add this record as a node in the G
@@ -265,7 +263,6 @@ async def extract_entities(
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
-        print (f"kw: processing a single chunk, {chunk_key_dp}")
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -435,7 +432,6 @@ async def local_query(
         text_chunks_db,
         query_param,
     )
-    print (f"got the following context {context} based on prompt keywords {keywords}")
     if query_param.only_need_context:
         return context
     if context is None:
@@ -444,7 +440,6 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
-    print (f"local query:{query} local sysprompt:{sys_prompt}")
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -470,20 +465,16 @@ async def _build_local_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("kw1: ENTITIES VDB QUERY**********************************")
     results = await entities_vdb.query(query, top_k=query_param.top_k)
-    print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")
     if not len(results):
         return None
-    print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
-    print ("kw4: getting node degrees next for the same entities/nodes")
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -729,7 +720,6 @@ async def _build_global_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("RELATIONSHIPS VDB QUERY**********************************")
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
 
     if not len(results):
@@ -895,14 +885,12 @@ async def hybrid_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    print ("HYBRID QUERY *********")
     low_level_context = None
     high_level_context = None
     use_model_func = global_config["llm_model_func"]
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
-    print ( f"kw:kw_prompt: {kw_prompt}")
     result = await use_model_func(kw_prompt)
 
     try:
@@ -911,8 +899,6 @@ async def hybrid_query(
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
         ll_keywords = ", ".join(ll_keywords)
-        print (f"High level key words: {hl_keywords}")
-        print (f"Low level key words: {ll_keywords}")
     except json.JSONDecodeError:
         try:
             result = (
@@ -942,7 +928,6 @@ async def hybrid_query(
             query_param,
         )
-    print (f"low_level_context: {low_level_context}")
 
     if hl_keywords:
         high_level_context = await _build_global_query_context(
@@ -953,7 +938,6 @@ async def hybrid_query(
             text_chunks_db,
             query_param,
         )
-    print (f"high_level_context: {high_level_context}")
 
     context = combine_contexts(high_level_context, low_level_context)
@@ -971,7 +955,6 @@ async def hybrid_query(
         query,
         system_prompt=sys_prompt,
     )
-    print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -1065,12 +1048,10 @@ async def naive_query(
 ):
     use_model_func = global_config["llm_model_func"]
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
-    print (f"raw chunks from chunks_vdb.query {results}")
     if not len(results):
         return PROMPTS["fail_response"]
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
-    print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")
     maybe_trun_chunks = truncate_list_by_token_size(
diff --git a/testkg.py b/testkg.py
index 2131f840..c90edde9 100644
--- a/testkg.py
+++ b/testkg.py
@@ -16,12 +16,13 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
-    kg="Neo4JStorage"
+    kg="Neo4JStorage",
+    log_level="INFO"
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
 
-with open("./book.txt") as f:
-    rag.insert(f.read())
+# with open("./book.txt") as f:
+#     rag.insert(f.read())
 
 # Perform naive search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
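
Note: after this patch, GraphStorage.__post_init__ reads its Neo4j connection settings from the URI, USERNAME, and PASSWORD environment variables and raises if any of them is missing, so the credentials must be set before LightRAG is constructed. A minimal sketch of the new code path (the connection values and working directory are placeholders, not part of the patch; the imports mirror testkg.py):

    import os

    from lightrag import LightRAG, QueryParam
    from lightrag.llm import gpt_4o_mini_complete

    # Placeholder credentials for a local Neo4j instance; substitute your own.
    os.environ["URI"] = "neo4j://localhost:7687"
    os.environ["USERNAME"] = "neo4j"
    os.environ["PASSWORD"] = "password"

    rag = LightRAG(
        working_dir="./local_neo4jWorkDir",   # placeholder path
        llm_model_func=gpt_4o_mini_complete,
        kg="Neo4JStorage",
        log_level="INFO",  # new field; __post_init__ forwards it to logger.setLevel()
    )
    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))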
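A missing variable now fails fast at construction time, but a wrong URI or bad password only surfaces once the driver runs its first query. A pre-flight check can make that failure earlier and clearer; this helper is illustrative only (not part of the patch) and relies on the Neo4j Python driver's verify_connectivity(), available in recent driver versions:

    import os

    from neo4j import GraphDatabase

    def check_neo4j_env() -> None:
        """Fail fast if the env-based credentials are absent or rejected."""
        missing = [k for k in ("URI", "USERNAME", "PASSWORD") if k not in os.environ]
        if missing:
            raise RuntimeError(f"Missing Neo4j credentials in environment: {missing}")
        driver = GraphDatabase.driver(
            os.environ["URI"],
            auth=(os.environ["USERNAME"], os.environ["PASSWORD"]),
        )
        try:
            driver.verify_connectivity()  # raises on a bad URI or rejected auth
        finally:
            driver.close()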