securing for production with env vars for creds

This commit is contained in:
Ken Wiltshire
2024-11-01 11:01:50 -04:00
parent e966a14418
commit b41d990fd6
5 changed files with 40 additions and 58 deletions

View File

@@ -5,7 +5,6 @@ from dataclasses import dataclass
from typing import Any, Union, cast
import numpy as np
import inspect
# import package.common.utils as utils
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
@@ -22,27 +21,6 @@ from tenacity import (
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
# during indexing.
# Replace with your actual URI, username, and password
#local
URI = "neo4j://localhost:7687"
USERNAME = "neo4j"
PASSWORD = "password"
#aura
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
# USERNAME = "neo4j"
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
# Create a driver object
@dataclass
class GraphStorage(BaseGraphStorage):
@staticmethod
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):
def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph()
credetial_parts = ['URI', 'USERNAME','PASSWORD']
credentials_set = all(x in os.environ for x in credetial_parts )
if credentials_set:
URI = os.environ["URI"]
USERNAME = os.environ["USERNAME"]
PASSWORD = os.environ["PASSWORD"]
else:
raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment")
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
result = tx.run(query)
single_result = result.single()
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
# if result.single() == None:
# print (f"this should not happen: ---- {label1}/{label2} {query}")
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
result = session.run(query)
for record in result:
result = record["n"]
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
)
return result
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
record = result.single()
if record:
edge_count = record["totalEdgeCount"]
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
)
return edge_count
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
RETURN count(r) AS degree"""
result = session.run(query)
record = result.single()
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
)
return record["degree"]
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
record = result.single()
if record:
result = dict(record["edge_properties"])
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
)
return result
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
# if source_label and target_label:
# connections.append((source_label, target_label))
# logger.info(
# logger.debug(
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
# )
# return connections
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
result = tx.run(query, properties=properties)
record = result.single()
if record:
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
)
return dict(record["n"])
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
"""
result = tx.run(query, properties=edge_properties)
logger.info(
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
)
return result.single()
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
with self._driver.session() as session:
#Define the Cypher query
options = self.global_config["node2vec_params"]
logger.info(f"building embeddings with options {options}")
logger.debug(f"building embeddings with options {options}")
query = f"""CALL gds.node2vec.write('91fbae6c', {
options
})