securing for production with env vars for creds

.gitignore (vendored)
@@ -7,4 +7,5 @@ lightrag-dev/
 dist/
 env/
 local_neo4jWorkDir/
 neo4jWorkDir/
+ignore_this.txt
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import Any, Union, cast
 import numpy as np
 import inspect
-# import package.common.utils as utils
 from lightrag.utils import load_json, logger, write_json
 from ..base import (
     BaseGraphStorage
@@ -22,27 +21,6 @@ from tenacity import (



-
-
-# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
-# during indexing.
-
-
-
-
-# Replace with your actual URI, username, and password
-#local
-URI = "neo4j://localhost:7687"
-USERNAME = "neo4j"
-PASSWORD = "password"
-
-#aura
-# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
-# USERNAME = "neo4j"
-# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
-# Create a driver object
-
-
 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):

     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
+        credential_parts = ['URI', 'USERNAME', 'PASSWORD']
+        credentials_set = all(x in os.environ for x in credential_parts)
+        if credentials_set:
+            URI = os.environ["URI"]
+            USERNAME = os.environ["USERNAME"]
+            PASSWORD = os.environ["PASSWORD"]
+        else:
+            raise Exception(f"One or more Neo4J Credentials, {credential_parts}, not found in the environment")
+
         self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
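Note: with this change GraphStorage no longer ships connection details in code; it expects URI, USERNAME and PASSWORD in the process environment and fails fast otherwise. Below is a minimal sketch of supplying them from Python before the storage is initialized. The variable names come from the hunk above; the values shown are placeholders (the localhost URI and "neo4j" user were the old dev defaults removed by this commit), not real credentials.

```python
import os

# Neo4j connection settings, read by GraphStorage.__post_init__ via os.environ.
# Replace the placeholder values with your own deployment's URI and login.
os.environ["URI"] = "neo4j://localhost:7687"   # or neo4j+s://<id>.databases.neo4j.io for Aura
os.environ["USERNAME"] = "neo4j"
os.environ["PASSWORD"] = "your-password-here"
```

In production these would normally come from the shell, a .env file, or the container/orchestrator rather than being set in code.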
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
             query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
             result = tx.run(query)
             single_result = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
             )

@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
             # if result.single() == None:
             # print (f"this should not happen: ---- {label1}/{label2} {query}")

-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
             )
@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
             result = session.run(query)
             for record in result:
                 result = record["n"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 edge_count = record["totalEdgeCount"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
                 )
                 return edge_count
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
                 RETURN count(r) AS degree"""
             result = session.run(query)
             record = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
             )
             return record["degree"]
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 result = dict(record["edge_properties"])
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
         # if source_label and target_label:
         # connections.append((source_label, target_label))

-        # logger.info(
+        # logger.debug(
         # f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
         # )
         # return connections
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
             result = tx.run(query, properties=properties)
             record = result.single()
             if record:
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
                 )
                 return dict(record["n"])
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
             """

             result = tx.run(query, properties=edge_properties)
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
             )
             return result.single()
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             #Define the Cypher query
             options = self.global_config["node2vec_params"]
-            logger.info(f"building embeddings with options {options}")
+            logger.debug(f"building embeddings with options {options}")
             query = f"""CALL gds.node2vec.write('91fbae6c', {
                 options
             })
@@ -28,6 +28,13 @@ from .storage import (
 from .kg.neo4j_impl import (
     GraphStorage as Neo4JStorage
 )
+#future KG integrations
+
+# from .kg.ArangoDB_impl import (
+# GraphStorage as ArangoDBStorage
+# )
+
+

 from .utils import (
     EmbeddingFunc,
@@ -64,7 +71,11 @@ class LightRAG:
     )

     kg: str = field(default="NetworkXStorage")

+    current_log_level = logger.level
+    log_level: str = field(default=current_log_level)
+
+

     # text chunking
     chunk_token_size: int = 1200
@@ -115,13 +126,14 @@ class LightRAG:
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
+        logger.setLevel(self.log_level)
+
         logger.info(f"Logger initialized for working directory: {self.working_dir}")

         _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         #should move all storage setup here to leverage initial start params attached to self.
-        print (f"self.kg set to: {self.kg}")
         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]

         if not os.path.exists(self.working_dir):
@@ -176,7 +188,7 @@ class LightRAG:
         return {
             "Neo4JStorage": Neo4JStorage,
             "NetworkXStorage": NetworkXStorage,
-            # "new_kg_here": KGClass
+            # "ArangoDBStorage": ArangoDBStorage
         }

     def insert(self, string_or_strings):
@@ -71,7 +71,6 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
     summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
-    print ("Summarized: {context_base} for entity relationship {} ")
     return summary


@@ -79,7 +78,6 @@ async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
-    print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
     # add this record as a node in the G
@@ -265,7 +263,6 @@ async def extract_entities(

     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
-        print (f"kw: processing a single chunk, {chunk_key_dp}")
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -435,7 +432,6 @@ async def local_query(
         text_chunks_db,
         query_param,
     )
-    print (f"got the following context {context} based on prompt keywords {keywords}")
     if query_param.only_need_context:
         return context
     if context is None:
@@ -444,7 +440,6 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
-    print (f"local query:{query} local sysprompt:{sys_prompt}")
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -470,20 +465,16 @@ async def _build_local_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("kw1: ENTITIES VDB QUERY**********************************")

     results = await entities_vdb.query(query, top_k=query_param.top_k)
-    print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")

     if not len(results):
         return None
-    print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
-    print ("kw4: getting node degrees next for the same entities/nodes")
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -729,7 +720,6 @@ async def _build_global_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("RELATIONSHIPS VDB QUERY**********************************")
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)

     if not len(results):
@@ -895,14 +885,12 @@ async def hybrid_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    print ("HYBRID QUERY *********")
     low_level_context = None
     high_level_context = None
     use_model_func = global_config["llm_model_func"]

     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
-    print ( f"kw:kw_prompt: {kw_prompt}")

     result = await use_model_func(kw_prompt)
     try:
@@ -911,8 +899,6 @@ async def hybrid_query(
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
         ll_keywords = ", ".join(ll_keywords)
-        print (f"High level key words: {hl_keywords}")
-        print (f"Low level key words: {ll_keywords}")
     except json.JSONDecodeError:
         try:
             result = (
@@ -942,7 +928,6 @@ async def hybrid_query(
             query_param,
         )

-        print (f"low_level_context: {low_level_context}")

     if hl_keywords:
         high_level_context = await _build_global_query_context(
@@ -953,7 +938,6 @@ async def hybrid_query(
             text_chunks_db,
             query_param,
         )
-        print (f"high_level_context: {high_level_context}")


     context = combine_contexts(high_level_context, low_level_context)
@@ -971,7 +955,6 @@ async def hybrid_query(
         query,
         system_prompt=sys_prompt,
     )
-    print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -1065,12 +1048,10 @@ async def naive_query(
 ):
     use_model_func = global_config["llm_model_func"]
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
-    print (f"raw chunks from chunks_vdb.query {results}")
     if not len(results):
         return PROMPTS["fail_response"]
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
-    print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")


     maybe_trun_chunks = truncate_list_by_token_size(
@@ -16,12 +16,13 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
-    kg="Neo4JStorage"
+    kg="Neo4JStorage",
+    log_level="INFO"
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )

-with open("./book.txt") as f:
-    rag.insert(f.read())
+# with open("./book.txt") as f:
+#     rag.insert(f.read())

 # Perform naive search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
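For completeness, a sketch of how the updated example might be run end to end with the environment variables in place. The LightRAG constructor arguments mirror the hunk above; the import paths, the gpt_4o_mini_complete helper, the working-directory value, and ./book.txt are assumptions based on the repository's existing example scripts rather than part of this diff.

```python
import os

from lightrag import LightRAG, QueryParam            # assumed import path
from lightrag.llm import gpt_4o_mini_complete         # assumed import path

# Credentials expected by the Neo4J-backed GraphStorage (see the __post_init__ hunk above).
os.environ.setdefault("URI", "neo4j://localhost:7687")
os.environ.setdefault("USERNAME", "neo4j")
os.environ.setdefault("PASSWORD", "your-password-here")

WORKING_DIR = "./local_neo4jWorkDir"  # placeholder working directory
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,
    kg="Neo4JStorage",   # select the Neo4j graph storage backend
    log_level="INFO",    # forwarded to the package logger in LightRAG.__post_init__
)

with open("./book.txt") as f:
    rag.insert(f.read())

print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
```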