Merge branch 'main' into main

This commit is contained in:
zrguo
2024-11-07 14:54:15 +08:00
committed by GitHub
12 changed files with 225 additions and 4803 deletions

View File

@@ -498,6 +498,10 @@ pip install fastapi uvicorn pydantic
2. Set up your environment variables: 2. Set up your environment variables:
```bash ```bash
export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default" export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
export OPENAI_BASE_URL="Your OpenAI API base URL" # Optional: Defaults to "https://api.openai.com/v1"
export OPENAI_API_KEY="Your OpenAI API key" # Required
export LLM_MODEL="Your LLM model" # Optional: Defaults to "gpt-4o-mini"
export EMBEDDING_MODEL="Your embedding model" # Optional: Defaults to "text-embedding-3-large"
``` ```
3. Run the API server: 3. Run the API server:
@@ -522,7 +526,8 @@ The API server provides the following endpoints:
```json ```json
{ {
"query": "Your question here", "query": "Your question here",
"mode": "hybrid" // Can be "naive", "local", "global", or "hybrid" "mode": "hybrid", // Can be "naive", "local", "global", or "hybrid"
"only_need_context": true // Optional; defaults to false. If true, only the referenced context is returned; otherwise the LLM answer is returned.
} }
``` ```
- **Example:** - **Example:**

View File

@@ -1,4 +1,4 @@
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
import os import os
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
@@ -18,22 +18,28 @@ app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory # Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}") print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
# LLM model function # LLM model function
async def llm_model_func( async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
return await openai_complete_if_cache( return await openai_complete_if_cache(
"gpt-4o-mini", LLM_MODEL,
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs, **kwargs,
) )
@@ -44,37 +50,41 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray: async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding( return await openai_embedding(
texts, texts,
model="text-embedding-3-large", model=EMBEDDING_MODEL,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
) )
async def get_embedding_dim():
    """Return the embedding width by embedding one fixed probe sentence.

    Awaits ``embedding_func`` on a single-item list and reads the second
    axis of the result's ``shape``. The detected value is also printed.
    """
    probe_batch = ["This is a test sentence."]
    vectors = await embedding_func(probe_batch)
    # The local keeps the name `embedding_dim` so that the self-documenting
    # f-string below prints exactly the same text as before.
    embedding_dim = vectors.shape[1]
    print(f"{embedding_dim=}")
    return embedding_dim
# Initialize RAG instance # Initialize RAG instance
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=llm_model_func, llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()),
embedding_dim=3072, max_token_size=8192, func=embedding_func max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
), func=embedding_func),
) )
# Data models # Data models
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
mode: str = "hybrid" mode: str = "hybrid"
only_need_context: bool = False
class InsertRequest(BaseModel): class InsertRequest(BaseModel):
text: str text: str
class InsertFileRequest(BaseModel):
file_path: str
class Response(BaseModel): class Response(BaseModel):
status: str status: str
data: Optional[str] = None data: Optional[str] = None
@@ -89,7 +99,8 @@ async def query_endpoint(request: QueryRequest):
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
result = await loop.run_in_executor( result = await loop.run_in_executor(
None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode)) None, lambda: rag.query(request.query,
param=QueryParam(mode=request.mode, only_need_context=request.only_need_context))
) )
return Response(status="success", data=result) return Response(status="success", data=result)
except Exception as e: except Exception as e:
@@ -107,30 +118,22 @@ async def insert_endpoint(request: InsertRequest):
@app.post("/insert_file", response_model=Response) @app.post("/insert_file", response_model=Response)
async def insert_file(request: InsertFileRequest): async def insert_file(file: UploadFile = File(...)):
try: try:
# Check if file exists file_content = await file.read()
if not os.path.exists(request.file_path):
raise HTTPException(
status_code=404, detail=f"File not found: {request.file_path}"
)
# Read file content # Read file content
try: try:
with open(request.file_path, "r", encoding="utf-8") as f: content = file_content.decode("utf-8")
content = f.read()
except UnicodeDecodeError: except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings # If UTF-8 decoding fails, try other encodings
with open(request.file_path, "r", encoding="gbk") as f: content = file_content.decode("gbk")
content = f.read()
# Insert file content # Insert file content
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content)) await loop.run_in_executor(None, lambda: rag.insert(content))
return Response( return Response(
status="success", status="success",
message=f"File content from {request.file_path} inserted successfully", message=f"File content from {file.filename} inserted successfully",
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,28 +1,34 @@
import networkx as nx import networkx as nx
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml') G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml")
def get_all_edges_and_nodes(G): def get_all_edges_and_nodes(G):
# Get all edges and their properties # Get all edges and their properties
edges_with_properties = [] edges_with_properties = []
for u, v, data in G.edges(data=True): for u, v, data in G.edges(data=True):
edges_with_properties.append({ edges_with_properties.append(
'start': u, {
'end': v, "start": u,
'label': data.get('label', ''), # Assuming 'label' is used for edge type "end": v,
'properties': data, "label": data.get(
'start_node_properties': G.nodes[u], "label", ""
'end_node_properties': G.nodes[v] ), # Assuming 'label' is used for edge type
}) "properties": data,
"start_node_properties": G.nodes[u],
"end_node_properties": G.nodes[v],
}
)
return edges_with_properties return edges_with_properties
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
# Assume G is your NetworkX graph loaded from Neo4j # Assume G is your NetworkX graph loaded from Neo4j
all_edges = get_all_edges_and_nodes(G) all_edges = get_all_edges_and_nodes(G)
# Print all edges and node properties # Print all edges and node properties
for edge in all_edges: for edge in all_edges:
print(f"Edge Label: {edge['label']}") print(f"Edge Label: {edge['label']}")
@@ -31,4 +37,4 @@ if __name__ == "__main__":
print(f"Start Node Properties: {edge['start_node_properties']}") print(f"Start Node Properties: {edge['start_node_properties']}")
print(f"End Node: {edge['end']}") print(f"End Node: {edge['end']}")
print(f"End Node Properties: {edge['end_node_properties']}") print(f"End Node Properties: {edge['end_node_properties']}")
print("---") print("---")

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,16 @@
import asyncio import asyncio
import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast, Tuple, List, Dict from typing import Any, Union, Tuple, List, Dict
import numpy as np
import inspect import inspect
from lightrag.utils import load_json, logger, write_json from lightrag.utils import logger
from ..base import ( from ..base import BaseGraphStorage
BaseGraphStorage from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
) )
from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import ( from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production") print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config): def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config) super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"] URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"] USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None return None
def __post_init__(self): def __post_init__(self):
@@ -43,59 +44,54 @@ class Neo4JStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def close(self): async def close(self):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
self._driver = None self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
async def index_done_callback(self): async def index_done_callback(self):
print ("KG successfully indexed.") print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" query = (
result = await session.run(query) f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
) )
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = ( query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
) )
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
) )
return single_result["edgeExists"] return single_result["edgeExists"]
def close(self):
self._driver.close()
def close(self):
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session: async with self._driver.session() as session:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()
@@ -103,54 +99,51 @@ class Neo4JStorage(BaseGraphStorage):
node = record["n"] node = record["n"]
node_dict = dict(node) node_dict = dict(node)
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}' f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
) )
return node_dict return node_dict
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = f""" query = f"""
MATCH (n:`{entity_name_label}`) MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount RETURN COUNT{{ (n)--() }} AS totalEdgeCount
""" """
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()
if record: if record:
edge_count = record["totalEdgeCount"] edge_count = record["totalEdgeCount"]
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
) )
return edge_count return edge_count
else: else:
return None return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('\"') entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('\"') entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source) src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target) trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition # Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}' f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
) )
return degrees return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: ) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('"')
""" """
Find all edges between nodes of two given labels Find all edges between nodes of two given labels
@@ -161,28 +154,30 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
async with self._driver.session() as session: async with self._driver.session() as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target) """.format(
entity_name_label_source=entity_name_label_source,
result = await session.run(query) entity_name_label_target=entity_name_label_target,
)
result = await session.run(query)
record = await result.single() record = await result.single()
if record: if record:
result = dict(record["edge_properties"]) result = dict(record["edge_properties"])
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
) )
return result return result
else: else:
return None return None
async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('\"') node_label = source_node_id.strip('"')
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session() as session: async with self._driver.session() as session:
results = await session.run(query) results = await session.run(query)
edges = [] edges = []
async for record in results: async for record in results:
source_node = record['n'] source_node = record["n"]
connected_node = record['connected'] connected_node = record["connected"]
source_label = list(source_node.labels)[0] if source_node.labels else None source_label = (
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None list(source_node.labels)[0] if source_node.labels else None
)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
if source_label and target_label: if source_label and target_label:
edges.append((source_label, target_label)) edges.append((source_label, target_label))
return edges
return edges
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
""" """
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = node_id.strip('\"') label = node_id.strip('"')
properties = node_data properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction): async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
SET n += $properties SET n += $properties
""" """
await tx.run(query, properties=properties) await tx.run(query, properties=properties)
logger.debug(f"Upserted node with label '{label}' and properties: {properties}") logger.debug(
f"Upserted node with label '{label}' and properties: {properties}"
)
try: try:
async with self._driver.session() as session: async with self._driver.session() as session:
@@ -236,13 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e: except Exception as e:
logger.error(f"Error during upsert: {str(e)}") logger.error(f"Error during upsert: {str(e)}")
raise raise
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
) )
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]): async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = source_node_id.strip('\"') source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('\"') target_node_label = target_node_id.strip('"')
edge_properties = edge_data edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction): async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
RETURN r RETURN r
""" """
await tx.run(query, properties=edge_properties) await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}") logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
)
try: try:
async with self._driver.session() as session: async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e: except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}") logger.error(f"Error during edge upsert: {str(e)}")
raise raise
async def _node2vec_embed(self): async def _node2vec_embed(self):
print ("Implemented but never called.") print("Implemented but never called.")

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import importlib
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@@ -24,18 +23,15 @@ from .storage import (
NanoVectorDBStorage, NanoVectorDBStorage,
NetworkXStorage, NetworkXStorage,
) )
from .kg.neo4j_impl import ( from .kg.neo4j_impl import Neo4JStorage
Neo4JStorage # future KG integrations
)
#future KG integrations
# from .kg.ArangoDB_impl import ( # from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage # GraphStorage as ArangoDBStorage
# ) # )
from .utils import ( from .utils import (
EmbeddingFunc, EmbeddingFunc,
compute_mdhash_id, compute_mdhash_id,
@@ -56,16 +52,18 @@ from .base import (
def always_get_an_event_loop() -> asyncio.AbstractEventLoop: def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try: try:
return asyncio.get_event_loop() return asyncio.get_event_loop()
except RuntimeError: except RuntimeError:
logger.info("Creating a new event loop in main thread.") logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
return loop return loop
@dataclass @dataclass
class LightRAG: class LightRAG:
working_dir: str = field( working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
) )
@@ -75,8 +73,6 @@ class LightRAG:
current_log_level = logger.level current_log_level = logger.level
log_level: str = field(default=current_log_level) log_level: str = field(default=current_log_level)
# text chunking # text chunking
chunk_token_size: int = 1200 chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100 chunk_overlap_token_size: int = 100
@@ -131,8 +127,10 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
#@TODO: should move all storage setup here to leverage initial start params attached to self. # @TODO: should move all storage setup here to leverage initial start params attached to self.
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg] self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.kg
]
if not os.path.exists(self.working_dir): if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
@@ -186,6 +184,7 @@ class LightRAG:
**self.llm_model_kwargs, **self.llm_model_kwargs,
) )
) )
def _get_storage_class(self) -> Type[BaseGraphStorage]: def _get_storage_class(self) -> Type[BaseGraphStorage]:
return { return {
"Neo4JStorage": Neo4JStorage, "Neo4JStorage": Neo4JStorage,
@@ -329,4 +328,4 @@ class LightRAG:
if storage_inst is None: if storage_inst is None:
continue continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -798,4 +798,4 @@ if __name__ == "__main__":
result = await gpt_4o_mini_complete("How are you?") result = await gpt_4o_mini_complete("How are you?")
print(result) print(result)
asyncio.run(main()) asyncio.run(main())

View File

@@ -466,7 +466,6 @@ async def _build_local_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema], text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam, query_param: QueryParam,
): ):
results = await entities_vdb.query(query, top_k=query_param.top_k) results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
@@ -483,7 +482,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d} {**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees) for k, n, d in zip(results, node_datas, node_degrees)
if n is not None if n is not None
]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram. ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities( use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst node_datas, query_param, text_chunks_db, knowledge_graph_inst
) )
@@ -946,7 +945,6 @@ async def hybrid_query(
query_param, query_param,
) )
if hl_keywords: if hl_keywords:
high_level_context = await _build_global_query_context( high_level_context = await _build_global_query_context(
hl_keywords, hl_keywords,
@@ -957,7 +955,6 @@ async def hybrid_query(
query_param, query_param,
) )
context = combine_contexts(high_level_context, low_level_context) context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context: if query_param.only_need_context:
@@ -1026,9 +1023,11 @@ def combine_contexts(high_level_context, low_level_context):
# Combine and deduplicate the entities # Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities) combined_entities = process_combine_contexts(hl_entities, ll_entities)
# Combine and deduplicate the relationships # Combine and deduplicate the relationships
combined_relationships = process_combine_contexts(hl_relationships, ll_relationships) combined_relationships = process_combine_contexts(
hl_relationships, ll_relationships
)
# Combine and deduplicate the sources # Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources) combined_sources = process_combine_contexts(hl_sources, ll_sources)
@@ -1064,7 +1063,6 @@ async def naive_query(
chunks_ids = [r["id"] for r in results] chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids) chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size( maybe_trun_chunks = truncate_list_by_token_size(
chunks, chunks,
key=lambda x: x["content"], key=lambda x: x["content"],
@@ -1095,4 +1093,4 @@ async def naive_query(
.strip() .strip()
) )
return response return response

View File

@@ -233,8 +233,7 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
#@TODO: NOT USED
async def _node2vec_embed(self): async def _node2vec_embed(self):
from graspologic import embed from graspologic import embed

View File

@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union,List from typing import Any, Union, List
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import numpy as np import numpy as np
@@ -176,19 +176,20 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i] return list_data[:i]
return list_data return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str: def list_of_list_to_csv(data: List[List[str]]) -> str:
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerows(data) writer.writerows(data)
return output.getvalue() return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]: def csv_string_to_list(csv_string: str) -> List[List[str]]:
output = io.StringIO(csv_string) output = io.StringIO(csv_string)
reader = csv.reader(output) reader = csv.reader(output)
return [row for row in reader] return [row for row in reader]
def save_data_to_file(data, file_name):
    """Write ``data`` to ``file_name`` as indented UTF-8 JSON.

    Args:
        data: Any object that ``json.dump`` can serialize.
        file_name: Path of the file to create or overwrite.
    """
    with open(file_name, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII characters as-is in the output
        # rather than emitting \uXXXX escapes.
        json.dump(data, f, ensure_ascii=False, indent=4)
@@ -253,13 +254,14 @@ def xml_to_json(xml_file):
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def process_combine_contexts(hl, ll): def process_combine_contexts(hl, ll):
header = None header = None
list_hl = csv_string_to_list(hl.strip()) list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip()) list_ll = csv_string_to_list(ll.strip())
if list_hl: if list_hl:
header=list_hl[0] header = list_hl[0]
list_hl = list_hl[1:] list_hl = list_hl[1:]
if list_ll: if list_ll:
header = list_ll[0] header = list_ll[0]
@@ -268,19 +270,17 @@ def process_combine_contexts(hl, ll):
return "" return ""
if list_hl: if list_hl:
list_hl = [','.join(item[1:]) for item in list_hl if item] list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll: if list_ll:
list_ll = [','.join(item[1:]) for item in list_ll if item] list_ll = [",".join(item[1:]) for item in list_ll if item]
combined_sources_set = set( combined_sources_set = set(filter(None, list_hl + list_ll))
filter(None, list_hl + list_ll)
)
combined_sources = [",\t".join(header)] combined_sources = [",\t".join(header)]
for i, item in enumerate(combined_sources_set, start=1): for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}") combined_sources.append(f"{i},\t{item}")
combined_sources = "\n".join(combined_sources) combined_sources = "\n".join(combined_sources)
return combined_sources return combined_sources

23
test.py
View File

@@ -1,11 +1,10 @@
import os import os
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio
# nest_asyncio.apply() # nest_asyncio.apply()
######### #########
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
@@ -15,7 +14,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
) )
@@ -23,13 +22,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -5,8 +5,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio
# nest_asyncio.apply() # nest_asyncio.apply()
######### #########
WORKING_DIR = "./local_neo4jWorkDir" WORKING_DIR = "./local_neo4jWorkDir"
@@ -18,7 +18,7 @@ rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
kg="Neo4JStorage", kg="Neo4JStorage",
log_level="INFO" log_level="INFO",
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
) )
@@ -26,13 +26,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)