securing for production with env vars for creds
This commit is contained in:
@@ -5,7 +5,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import numpy as np
|
||||
import inspect
|
||||
# import package.common.utils as utils
|
||||
from lightrag.utils import load_json, logger, write_json
|
||||
from ..base import (
|
||||
BaseGraphStorage
|
||||
@@ -22,27 +21,6 @@ from tenacity import (
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
|
||||
# during indexing.
|
||||
|
||||
|
||||
|
||||
|
||||
# Replace with your actual URI, username, and password
|
||||
#local
|
||||
URI = "neo4j://localhost:7687"
|
||||
USERNAME = "neo4j"
|
||||
PASSWORD = "password"
|
||||
|
||||
#aura
|
||||
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
|
||||
# USERNAME = "neo4j"
|
||||
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
|
||||
# Create a driver object
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
def __post_init__(self):
|
||||
# self._graph = preloaded_graph or nx.Graph()
|
||||
credetial_parts = ['URI', 'USERNAME','PASSWORD']
|
||||
credentials_set = all(x in os.environ for x in credetial_parts )
|
||||
if credentials_set:
|
||||
URI = os.environ["URI"]
|
||||
USERNAME = os.environ["USERNAME"]
|
||||
PASSWORD = os.environ["PASSWORD"]
|
||||
else:
|
||||
raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment")
|
||||
|
||||
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
|
||||
result = tx.run(query)
|
||||
single_result = result.single()
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
|
||||
)
|
||||
|
||||
@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
# if result.single() == None:
|
||||
# print (f"this should not happen: ---- {label1}/{label2} {query}")
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
|
||||
)
|
||||
|
||||
@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
result = session.run(query)
|
||||
for record in result:
|
||||
result = record["n"]
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
|
||||
)
|
||||
return result
|
||||
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
record = result.single()
|
||||
if record:
|
||||
edge_count = record["totalEdgeCount"]
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
|
||||
)
|
||||
return edge_count
|
||||
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
RETURN count(r) AS degree"""
|
||||
result = session.run(query)
|
||||
record = result.single()
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
|
||||
)
|
||||
return record["degree"]
|
||||
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
record = result.single()
|
||||
if record:
|
||||
result = dict(record["edge_properties"])
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
|
||||
)
|
||||
return result
|
||||
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
# if source_label and target_label:
|
||||
# connections.append((source_label, target_label))
|
||||
|
||||
# logger.info(
|
||||
# logger.debug(
|
||||
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
|
||||
# )
|
||||
# return connections
|
||||
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
result = tx.run(query, properties=properties)
|
||||
record = result.single()
|
||||
if record:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
|
||||
)
|
||||
return dict(record["n"])
|
||||
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
|
||||
result = tx.run(query, properties=edge_properties)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
|
||||
)
|
||||
return result.single()
|
||||
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
with self._driver.session() as session:
|
||||
#Define the Cypher query
|
||||
options = self.global_config["node2vec_params"]
|
||||
logger.info(f"building embeddings with options {options}")
|
||||
logger.debug(f"building embeddings with options {options}")
|
||||
query = f"""CALL gds.node2vec.write('91fbae6c', {
|
||||
options
|
||||
})
|
||||
|
Reference in New Issue
Block a user