diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 5ffbf2bc..03b1bbcb 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -69,12 +69,28 @@ class Neo4JStorage(BaseGraphStorage): config.get("neo4j", "connection_pool_size", fallback=800), ) ) + CONNECTION_TIMEOUT = float( + os.environ.get( + "NEO4J_CONNECTION_TIMEOUT", + config.get("neo4j", "connection_timeout", fallback=60.0), + ), + ) + CONNECTION_ACQUISITION_TIMEOUT = float( + os.environ.get( + "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", + config.get("neo4j", "connection_acquisition_timeout", fallback=60.0), + ), + ) DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) ) self._driver: AsyncDriver = AsyncGraphDatabase.driver( - URI, auth=(USERNAME, PASSWORD) + URI, + auth=(USERNAME, PASSWORD), + max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, + connection_timeout=CONNECTION_TIMEOUT, + connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, ) # Try to connect to the database @@ -82,6 +98,8 @@ class Neo4JStorage(BaseGraphStorage): URI, auth=(USERNAME, PASSWORD), max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, + connection_timeout=CONNECTION_TIMEOUT, + connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, ) as _sync_driver: for database in (DATABASE, None): self._DATABASE = database @@ -278,14 +296,16 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) record = await result.single() - if record and "edge_properties" in record: + if record: try: result = dict(record["edge_properties"]) + logger.info(f"Result: {result}") # Ensure required keys exist with defaults required_keys = { "weight": 0.0, "source_id": None, - "target_id": None, + "description": None, + "keywords": None, } for key, default_value in required_keys.items(): if key not in result: @@ -305,20 +325,35 @@ class Neo4JStorage(BaseGraphStorage): f"and {entity_name_label_target}: {str(e)}" ) # Return default edge properties on error - return {"weight": 0.0, "source_id": None, "target_id": None} + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } logger.debug( f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" ) # Return default edge properties when no edge found - return {"weight": 0.0, "source_id": None, "target_id": None} + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } except Exception as e: logger.error( f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" ) # Return default edge properties on error - return {"weight": 0.0, "source_id": None, "target_id": None} + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_label = source_node_id.strip('"')