@@ -4,6 +4,7 @@ import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
import asyncio

if not pm.is_installed("pymongo"):
    pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError

from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
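# The connection settings used below resolve in order: environment variable
# first, then config.ini, then a hard-coded localhost fallback. A sketch of
# the expected config.ini layout (values are illustrative, not required):
#
#   [mongodb]
#   uri = mongodb://root:root@localhost:27017/
#   database = LightRAG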
@dataclass
class MongoKVStorage(BaseKVStorage):
    def __post_init__(self):
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as KV {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        return await self._data.find_one({"_id": id})
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        cursor = self._data.find({"_id": {"$in": ids}})
        return await cursor.to_list()
    async def filter_keys(self, data: set[str]) -> set[str]:
        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
        existing_ids = {str(x["_id"]) async for x in cursor}
        return data - existing_ids
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            update_tasks = []
            for mode, items in data.items():
                for k, v in items.items():
                    key = f"{mode}_{k}"
                    data[mode][k]["_id"] = key
                    update_tasks.append(
                        self._data.update_one(
                            {"_id": key}, {"$setOnInsert": v}, upsert=True
                        )
                    )
            await asyncio.gather(*update_tasks)
        else:
            update_tasks = []
            for k, v in data.items():
                data[k]["_id"] = k
                update_tasks.append(
                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
                )
            await asyncio.gather(*update_tasks)
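    # Note on the two update flavors above: "$setOnInsert" writes its fields
    # only when the upsert actually inserts a new document, which makes the
    # LLM response cache effectively write-once, while "$set" always
    # overwrites. A sketch against an existing document with _id "k":
    #
    #   await col.update_one({"_id": "k"}, {"$setOnInsert": v}, upsert=True)
    #   # existing document is left untouched
    #   await col.update_one({"_id": "k"}, {"$set": v}, upsert=True)
    #   # matching fields are overwritten with v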
    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            res = {}
            v = await self._data.find_one({"_id": mode + "_" + id})
            if v:
                res[id] = v
                logger.debug(f"llm_response_cache find one by: {id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
    def __post_init__(self):
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as doc status {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        return await self._data.find_one({"_id": id})
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        cursor = self._data.find({"_id": {"$in": ids}})
        return await cursor.to_list()
    async def filter_keys(self, data: set[str]) -> set[str]:
        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
        existing_ids = {str(x["_id"]) async for x in cursor}
        return data - existing_ids
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        update_tasks = []
        for k, v in data.items():
            data[k]["_id"] = k
            update_tasks.append(
                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
            )
        await asyncio.gather(*update_tasks)
    async def drop(self) -> None:
        """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
        cursor = self._data.aggregate(pipeline)
        result = await cursor.to_list()
        counts = {}
        for doc in result:
            counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents by status"""
        cursor = self._data.find({"status": status.value})
        result = await cursor.to_list()
        return {
            doc["_id"]: DocProcessingStatus(
                content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
            global_config=global_config,
            embedding_func=embedding_func,
        )
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self.collection = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as KG {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
    #
    # -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
        self, source_node_id: str
    ) -> Union[List[Tuple[str, str]], None]:
        """
        Return a list of (source_id, target_id) for direct edges from source_node_id.
        Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
        """
        pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
            return None

        edges = result[0].get("edges", [])
        return [(source_node_id, e["target"]) for e in edges]
    #
    # -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
    async def delete_node(self, node_id: str):
        """
        1) Remove node's doc entirely.
        2) Remove inbound edges from any doc that references node_id.
        """
        # Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
        Placeholder for demonstration, raises NotImplementedError.
        """
        raise NotImplementedError("Node embedding is not used in lightrag.")
    #
    # -------------------------------------------------------------------------
    # QUERY
    # -------------------------------------------------------------------------
    #
    async def get_all_labels(self) -> list[str]:
        """
        Get all existing node _id values in the database.

        Returns:
            [id1, id2, ...]  # Alphabetically sorted id list
        """
        # Use aggregation to collect all unique node ids, sorted alphabetically
        pipeline = [
            {"$group": {"_id": "$_id"}},  # Group by _id
            {"$sort": {"_id": 1}},  # Sort alphabetically
        ]
        cursor = self.collection.aggregate(pipeline)
        labels = []
        async for doc in cursor:
            labels.append(doc["_id"])
        return labels
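    # The aggregation above is effectively a distinct() over document ids:
    # every graph node is stored as a single document keyed by _id, so
    # grouping on "$_id" yields each node id exactly once, e.g.
    # [{"_id": "Alice"}, {"_id": "Bob"}, ...] before the ids are collected
    # into a plain sorted list.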
    async def get_knowledge_graph(
        self, node_label: str, max_depth: int = 5
    ) -> KnowledgeGraph:
        """
        Get the complete connected subgraph for a specified node (including the starting node itself).

        Args:
            node_label: Label of the node to start from
            max_depth: Maximum depth of traversal (default: 5)

        Returns:
            KnowledgeGraph object containing nodes and edges of the subgraph
        """
        label = node_label
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()
        try:
            if label == "*":
                # Get all nodes and edges
                async for node_doc in self.collection.find({}):
                    node_id = str(node_doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_doc.get("_id")],
                                properties={
                                    k: v
                                    for k, v in node_doc.items()
                                    if k not in ["_id", "edges"]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Process edges
                    for edge in node_doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)
            else:
                # Verify that the starting node exists
                start_nodes = self.collection.find({"_id": label})
                start_nodes_exist = await start_nodes.to_list(length=1)
                if not start_nodes_exist:
                    logger.warning(f"Starting node with label {label} does not exist!")
                    return result

                # Use $graphLookup for traversal
                pipeline = [
                    {
                        "$match": {"_id": label}
                    },  # Start with nodes having the specified label
                    {
                        "$graphLookup": {
                            "from": self._collection_name,
                            "startWith": "$edges.target",
                            "connectFromField": "edges.target",
                            "connectToField": "_id",
                            "maxDepth": max_depth,
                            "depthField": "depth",
                            "as": "connected_nodes",
                        }
                    },
                ]
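                # $graphLookup starts from this document's edges.target values
                # and repeatedly joins them back to "_id" in the same
                # collection, with "maxDepth" capping the recursion. Matches
                # arrive as one flat array (not nested per level), each doc
                # carrying its hop count in "depth", roughly:
                #
                #   {"_id": "A", "edges": [...],
                #    "connected_nodes": [
                #        {"_id": "B", "depth": 0, ...},
                #        {"_id": "C", "depth": 1, ...},
                #    ]}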
                async for doc in self.collection.aggregate(pipeline):
                    # Add the start node
                    node_id = str(doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[doc.get("_id")],
                                properties={
                                    k: v
                                    for k, v in doc.items()
                                    if k
                                    not in [
                                        "_id",
                                        "edges",
                                        "connected_nodes",
                                        "depth",
                                    ]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Add edges from the start node
                    for edge in doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)

                    # Add connected nodes and their edges
                    for connected in doc.get("connected_nodes", []):
                        node_id = str(connected["_id"])
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=node_id,
                                    labels=[connected.get("_id")],
                                    properties={
                                        k: v
                                        for k, v in connected.items()
                                        if k not in ["_id", "edges", "depth"]
                                    },
                                )
                            )
                            seen_nodes.add(node_id)

                        # Add edges from connected nodes
                        for edge in connected.get("edges", []):
                            edge_id = f"{node_id}-{edge['target']}"
                            if edge_id not in seen_edges:
                                result.edges.append(
                                    KnowledgeGraphEdge(
                                        id=edge_id,
                                        type=edge.get("relation", ""),
                                        source=node_id,
                                        target=edge["target"],
                                        properties={
                                            k: v
                                            for k, v in edge.items()
                                            if k not in ["target", "relation"]
                                        },
                                    )
                                )
                                seen_edges.add(edge_id)

            logger.info(
                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
            )
        except PyMongoError as e:
            logger.error(f"MongoDB query failed: {str(e)}")

        return result
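# A minimal usage sketch for the traversal above (the instance name and its
# construction are illustrative, not part of this module):
#
#   kg = await graph_storage.get_knowledge_graph("Alice", max_depth=2)
#   print([n.id for n in kg.nodes])
#   print([(e.source, e.target, e.type) for e in kg.edges])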
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]
        logger.debug(f"Use MongoDB as VDB {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)

        # Ensure vector index exists
        self.create_vector_index(uri, database.name, self._collection_name)
    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
        """Creates an Atlas Vector Search index."""
        client = MongoClient(uri)
        collection = client.get_database(database_name).get_collection(
            self._collection_name
        )

        try:
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
                            "path": "vector",
                            "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
                        }
                    ]
                },
                name="vector_knn_index",
                type="vectorSearch",
            )
            collection.create_search_index(search_index_model)
            logger.info("Vector index created successfully.")
        except PyMongoError as _:
            logger.debug("Vector index already exists.")
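    # Note: create_search_index() and the $vectorSearch stage used in query()
    # are Atlas Vector Search features. On a deployment without search index
    # support the call raises a PyMongoError, which the handler above logs as
    # if the index already existed; if vector queries return nothing, first
    # verify that the deployment actually supports search indexes.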
    async def upsert(self, data: dict[str, dict]):
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()

        update_tasks = []
        for doc in list_data:
            update_tasks.append(
                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            )
        await asyncio.gather(*update_tasks)

        return list_data
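    # The embeddings are stored as plain Python lists of floats ("vector")
    # because BSON has no native numpy/float32 array type; Atlas Vector Search
    # consumes the list form directly.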
    async def query(self, query, top_k=5):
        """Queries the vector database using Atlas Vector Search."""
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        # Define the aggregation pipeline with the converted query vector
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Ensure this matches the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    "numCandidates": 100,  # Adjust for performance
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline
        cursor = self._data.aggregate(pipeline)
        results = await cursor.to_list()

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]
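    # A usage sketch (the instance name is illustrative):
    #
    #   hits = await vdb.query("what is a knowledge graph?", top_k=3)
    #   for hit in hits:
    #       print(hit["id"], hit["distance"])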
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists; if not, create it."""
    client = MongoClient(uri)
    database = client.get_database(database_name)

    collection_names = database.list_collection_names()

    if collection_name not in collection_names:
        database.create_collection(collection_name)
        logger.info(f"Created collection: {collection_name}")
    else:
        logger.debug(f"Collection '{collection_name}' already exists.")
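# create_collection_if_not_exists() deliberately uses the synchronous
# MongoClient: it is called from __init__/__post_init__, which are not async,
# so a one-off blocking round trip at startup is the simplest way to guarantee
# the collection exists before any Motor operation touches it.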