Merge branch 'HKUDS:main' into main
This commit is contained in:
@@ -69,12 +69,28 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
config.get("neo4j", "connection_pool_size", fallback=800),
|
config.get("neo4j", "connection_pool_size", fallback=800),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
CONNECTION_TIMEOUT = float(
|
||||||
|
os.environ.get(
|
||||||
|
"NEO4J_CONNECTION_TIMEOUT",
|
||||||
|
config.get("neo4j", "connection_timeout", fallback=60.0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
CONNECTION_ACQUISITION_TIMEOUT = float(
|
||||||
|
os.environ.get(
|
||||||
|
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
||||||
|
config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
|
||||||
|
),
|
||||||
|
)
|
||||||
DATABASE = os.environ.get(
|
DATABASE = os.environ.get(
|
||||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||||
URI, auth=(USERNAME, PASSWORD)
|
URI,
|
||||||
|
auth=(USERNAME, PASSWORD),
|
||||||
|
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||||
|
connection_timeout=CONNECTION_TIMEOUT,
|
||||||
|
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to connect to the database
|
# Try to connect to the database
|
||||||
@@ -82,6 +98,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
URI,
|
URI,
|
||||||
auth=(USERNAME, PASSWORD),
|
auth=(USERNAME, PASSWORD),
|
||||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||||
|
connection_timeout=CONNECTION_TIMEOUT,
|
||||||
|
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
||||||
) as _sync_driver:
|
) as _sync_driver:
|
||||||
for database in (DATABASE, None):
|
for database in (DATABASE, None):
|
||||||
self._DATABASE = database
|
self._DATABASE = database
|
||||||
@@ -278,14 +296,16 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
record = await result.single()
|
record = await result.single()
|
||||||
if record and "edge_properties" in record:
|
if record:
|
||||||
try:
|
try:
|
||||||
result = dict(record["edge_properties"])
|
result = dict(record["edge_properties"])
|
||||||
|
logger.info(f"Result: {result}")
|
||||||
# Ensure required keys exist with defaults
|
# Ensure required keys exist with defaults
|
||||||
required_keys = {
|
required_keys = {
|
||||||
"weight": 0.0,
|
"weight": 0.0,
|
||||||
"source_id": None,
|
"source_id": None,
|
||||||
"target_id": None,
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
}
|
}
|
||||||
for key, default_value in required_keys.items():
|
for key, default_value in required_keys.items():
|
||||||
if key not in result:
|
if key not in result:
|
||||||
@@ -305,20 +325,35 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
f"and {entity_name_label_target}: {str(e)}"
|
f"and {entity_name_label_target}: {str(e)}"
|
||||||
)
|
)
|
||||||
# Return default edge properties on error
|
# Return default edge properties on error
|
||||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
return {
|
||||||
|
"weight": 0.0,
|
||||||
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
|
"source_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
|
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
|
||||||
)
|
)
|
||||||
# Return default edge properties when no edge found
|
# Return default edge properties when no edge found
|
||||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
return {
|
||||||
|
"weight": 0.0,
|
||||||
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
|
"source_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
|
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
# Return default edge properties on error
|
# Return default edge properties on error
|
||||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
return {
|
||||||
|
"weight": 0.0,
|
||||||
|
"description": None,
|
||||||
|
"keywords": None,
|
||||||
|
"source_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
node_label = source_node_id.strip('"')
|
node_label = source_node_id.strip('"')
|
||||||
|
Reference in New Issue
Block a user