Merge pull request #781 from ArnoChenFx/enhance-mongodb-backends
Add MongoDB VectorDB Support and Optimize Existing MongoDB Implementations
@@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB
 PGVectorStorage Postgres
 FaissVectorDBStorage Faiss
 QdrantVectorDBStorage Qdrant
-OracleVectorDBStorag Oracle
+OracleVectorDBStorage Oracle
+MongoVectorDBStorage MongoDB
 ```

 * DOC_STATUS_STORAGE:supported implement-name
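For context, a hedged sketch (not from this commit) of how the new backend is selected. The constructor arguments below reflect LightRAG's string-based storage selection, as suggested by the table above and by the STORAGE_IMPLEMENTATIONS registry later in this diff; the embedding and LLM setup is elided:

```python
# Hedged wiring sketch, not part of this commit. Backends are chosen by class
# name; MongoVectorDBStorage additionally requires a score threshold via
# vector_db_storage_cls_kwargs (it raises ValueError otherwise, see below).
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    kv_storage="MongoKVStorage",
    graph_storage="MongoGraphStorage",
    doc_status_storage="MongoDocStatusStorage",
    vector_storage="MongoVectorDBStorage",
    vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2},
    # embedding_func / llm_model_func omitted for brevity
)
```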
@@ -4,6 +4,7 @@ import numpy as np
 import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
+import asyncio

 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError

 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
+    BaseVectorStorage,
     DocProcessingStatus,
     DocStatus,
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge


 config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
         database = client.get_database(
             os.environ.get(
                 "MONGO_DATABASE",
                 config.get("mongodb", "database", fallback="LightRAG"),
             )
         )
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as KV {self.namespace}")
+
+        self._collection_name = self.namespace
+        self._data = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as KV {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()

     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
+            update_tasks = []
             for mode, items in data.items():
-                for k, v in tqdm_async(items.items(), desc="Upserting"):
+                for k, v in items.items():
                     key = f"{mode}_{k}"
-                    result = self._data.update_one(
-                        {"_id": key}, {"$setOnInsert": v}, upsert=True
-                    )
-                    if result.upserted_id:
-                        logger.debug(f"\nInserted new document with key: {key}")
-                    data[mode][k]["_id"] = key
+                    data[mode][k]["_id"] = f"{mode}_{k}"
+                    update_tasks.append(
+                        self._data.update_one(
+                            {"_id": key}, {"$setOnInsert": v}, upsert=True
+                        )
+                    )
+            await asyncio.gather(*update_tasks)
         else:
-            for k, v in tqdm_async(data.items(), desc="Upserting"):
-                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
-                data[k]["_id"] = k
+            update_tasks = []
+            for k, v in data.items():
+                data[k]["_id"] = k
+                update_tasks.append(
+                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+                )
+            await asyncio.gather(*update_tasks)

     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             res = {}
-            v = self._data.find_one({"_id": mode + "_" + id})
+            v = await self._data.find_one({"_id": mode + "_" + id})
             if v:
                 res[id] = v
                 logger.debug(f"llm_response_cache find one by:{id}")
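The recurring pattern in this hunk: instead of issuing blocking PyMongo calls one at a time, the Motor coroutines are collected and awaited together. A minimal standalone sketch, with placeholder URI, database, and collection names:

```python
# Minimal sketch of the sequential -> concurrent upsert change. With Motor,
# update_one() returns an awaitable, so all writes can be in flight at once
# instead of blocking the event loop on each round-trip.
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def upsert_all(data: dict[str, dict]) -> None:
    coll = AsyncIOMotorClient("mongodb://root:root@localhost:27017/")["LightRAG"]["demo_kv"]
    update_tasks = [
        coll.update_one({"_id": k}, {"$set": v}, upsert=True)
        for k, v in data.items()
    ]
    await asyncio.gather(*update_tasks)  # one concurrent pass instead of N blocking calls

asyncio.run(upsert_all({"doc-1": {"content": "hello"}, "doc-2": {"content": "world"}}))
```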
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
-        )
-        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as doc status {self.namespace}")
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self._data = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()

     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        update_tasks = []
         for k, v in data.items():
-            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
             data[k]["_id"] = k
+            update_tasks.append(
+                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            )
+        await asyncio.gather(*update_tasks)

     async def drop(self) -> None:
         """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
-        result = list(self._data.aggregate(pipeline))
+        cursor = self._data.aggregate(pipeline)
+        result = await cursor.to_list()
         counts = {}
         for doc in result:
             counts[doc["_id"]] = doc["count"]
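The same cursor idiom in isolation: Motor's aggregate() hands back a lazy cursor, and to_list() drains it asynchronously. A sketch with placeholder names and illustrative status values:

```python
# Sketch of the aggregate -> to_list pattern above; database and collection
# names are placeholders. The cursor is created synchronously and only does
# I/O when drained.
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def status_counts() -> dict[str, int]:
    coll = AsyncIOMotorClient("mongodb://root:root@localhost:27017/")["LightRAG"]["doc_status"]
    pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
    result = await coll.aggregate(pipeline).to_list(length=None)
    return {doc["_id"]: doc["count"] for doc in result}  # e.g. {"pending": 3, "processed": 42}

print(asyncio.run(status_counts()))
```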
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        result = list(self._data.find({"status": status.value}))
+        cursor = self._data.find({"status": status.value})
+        result = await cursor.to_list()
         return {
             doc["_id"]: DocProcessingStatus(
                 content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
             global_config=global_config,
             embedding_func=embedding_func,
         )
-        self.client = AsyncIOMotorClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
-        self.db = self.client[
-            os.environ.get(
-                "MONGO_DATABASE",
-                mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        ]
-        self.collection = self.db[
-            os.environ.get(
-                "MONGO_KG_COLLECTION",
-                config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
-            )
-        ]
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self.collection = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as KG {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

     #
     # -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
         self, source_node_id: str
     ) -> Union[List[Tuple[str, str]], None]:
         """
-        Return a list of (target_id, relation) for direct edges from source_node_id.
+        Return a list of (source_id, target_id) for direct edges from source_node_id.
         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
         """
         pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
             return None

         edges = result[0].get("edges", [])
-        return [(e["target"], e["relation"]) for e in edges]
+        return [(source_node_id, e["target"]) for e in edges]

     #
     # -------------------------------------------------------------------------
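Note the behavioral change in this one-liner: callers of get_node_edges now receive (source_id, target_id) pairs instead of (target_id, relation) pairs, matching the updated docstring above. With made-up values:

```python
# Illustrative values only; shows the changed tuple shape.
source_node_id = "node-A"
edges = [{"target": "node-B", "relation": "related_to"}]

old_result = [(e["target"], e["relation"]) for e in edges]   # [("node-B", "related_to")]
new_result = [(source_node_id, e["target"]) for e in edges]  # [("node-A", "node-B")]
```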
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):

     async def delete_node(self, node_id: str):
         """
-        1) Remove node’s doc entirely.
+        1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
         """
         # Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
         Placeholder for demonstration, raises NotImplementedError.
         """
         raise NotImplementedError("Node embedding is not used in lightrag.")
+
+    #
+    # -------------------------------------------------------------------------
+    # QUERY
+    # -------------------------------------------------------------------------
+    #
+
+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all existing node _id in the database
+        Returns:
+            [id1, id2, ...]  # Alphabetically sorted id list
+        """
+        # Use MongoDB's distinct and aggregation to get all unique labels
+        pipeline = [
+            {"$group": {"_id": "$_id"}},  # Group by _id
+            {"$sort": {"_id": 1}},  # Sort alphabetically
+        ]
+
+        cursor = self.collection.aggregate(pipeline)
+        labels = []
+        async for doc in cursor:
+            labels.append(doc["_id"])
+        return labels
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """
+        Get complete connected subgraph for specified node (including the starting node itself)
+
+        Args:
+            node_label: Label of the nodes to start from
+            max_depth: Maximum depth of traversal (default: 5)
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges of the subgraph
+        """
+        label = node_label
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        try:
+            if label == "*":
+                # Get all nodes and edges
+                async for node_doc in self.collection.find({}):
+                    node_id = str(node_doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_doc.get("_id")],
+                                properties={
+                                    k: v
+                                    for k, v in node_doc.items()
+                                    if k not in ["_id", "edges"]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Process edges
+                    for edge in node_doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+            else:
+                # Verify if starting node exists
+                start_nodes = self.collection.find({"_id": label})
+                start_nodes_exist = await start_nodes.to_list(length=1)
+                if not start_nodes_exist:
+                    logger.warning(f"Starting node with label {label} does not exist!")
+                    return result
+
+                # Use $graphLookup for traversal
+                pipeline = [
+                    {
+                        "$match": {"_id": label}
+                    },  # Start with nodes having the specified label
+                    {
+                        "$graphLookup": {
+                            "from": self._collection_name,
+                            "startWith": "$edges.target",
+                            "connectFromField": "edges.target",
+                            "connectToField": "_id",
+                            "maxDepth": max_depth,
+                            "depthField": "depth",
+                            "as": "connected_nodes",
+                        }
+                    },
+                ]
+
+                async for doc in self.collection.aggregate(pipeline):
+                    # Add the start node
+                    node_id = str(doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[
+                                    doc.get(
+                                        "_id",
+                                    )
+                                ],
+                                properties={
+                                    k: v
+                                    for k, v in doc.items()
+                                    if k
+                                    not in [
+                                        "_id",
+                                        "edges",
+                                        "connected_nodes",
+                                        "depth",
+                                    ]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Add edges from start node
+                    for edge in doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+
+                    # Add connected nodes and their edges
+                    for connected in doc.get("connected_nodes", []):
+                        node_id = str(connected["_id"])
+                        if node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=node_id,
+                                    labels=[connected.get("_id")],
+                                    properties={
+                                        k: v
+                                        for k, v in connected.items()
+                                        if k not in ["_id", "edges", "depth"]
+                                    },
+                                )
+                            )
+                            seen_nodes.add(node_id)
+
+                        # Add edges from connected nodes
+                        for edge in connected.get("edges", []):
+                            edge_id = f"{node_id}-{edge['target']}"
+                            if edge_id not in seen_edges:
+                                result.edges.append(
+                                    KnowledgeGraphEdge(
+                                        id=edge_id,
+                                        type=edge.get("relation", ""),
+                                        source=node_id,
+                                        target=edge["target"],
+                                        properties={
+                                            k: v
+                                            for k, v in edge.items()
+                                            if k not in ["target", "relation"]
+                                        },
+                                    )
+                                )
+                                seen_edges.add(edge_id)
+
+            logger.info(
+                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+            )
+
+        except PyMongoError as e:
+            logger.error(f"MongoDB query failed: {str(e)}")
+
+        return result
+
+
+@dataclass
+class MongoVectorDBStorage(BaseVectorStorage):
+    cosine_better_than_threshold: float = None
+
+    def __post_init__(self):
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold
+
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self._data = database.get_collection(self._collection_name)
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+
+        logger.debug(f"Use MongoDB as VDB {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)
+
+        # Ensure vector index exists
+        self.create_vector_index(uri, database.name, self._collection_name)
+
+    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
+        """Creates an Atlas Vector Search index."""
+        client = MongoClient(uri)
+        collection = client.get_database(database_name).get_collection(
+            self._collection_name
+        )
+
+        try:
+            search_index_model = SearchIndexModel(
+                definition={
+                    "fields": [
+                        {
+                            "type": "vector",
+                            "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
+                            "path": "vector",
+                            "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
+                        }
+                    ]
+                },
+                name="vector_knn_index",
+                type="vectorSearch",
+            )
+
+            collection.create_search_index(search_index_model)
+            logger.info("Vector index created successfully.")
+
+        except PyMongoError as _:
+            logger.debug("vector index already exist")
+
+    async def upsert(self, data: dict[str, dict]):
+        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You are inserting an empty data set to vector DB")
+            return []
+
+        list_data = [
+            {
+                "_id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        async def wrapped_task(batch):
+            result = await self.embedding_func(batch)
+            pbar.update(1)
+            return result
+
+        embedding_tasks = [wrapped_task(batch) for batch in batches]
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
+        embeddings_list = await asyncio.gather(*embedding_tasks)
+
+        embeddings = np.concatenate(embeddings_list)
+        for i, d in enumerate(list_data):
+            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
+
+        update_tasks = []
+        for doc in list_data:
+            update_tasks.append(
+                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
+            )
+        await asyncio.gather(*update_tasks)
+
+        return list_data
+
+    async def query(self, query, top_k=5):
+        """Queries the vector database using Atlas Vector Search."""
+        # Generate the embedding
+        embedding = await self.embedding_func([query])
+
+        # Convert numpy array to a list to ensure compatibility with MongoDB
+        query_vector = embedding[0].tolist()
+
+        # Define the aggregation pipeline with the converted query vector
+        pipeline = [
+            {
+                "$vectorSearch": {
+                    "index": "vector_knn_index",  # Ensure this matches the created index name
+                    "path": "vector",
+                    "queryVector": query_vector,
+                    "numCandidates": 100,  # Adjust for performance
+                    "limit": top_k,
+                }
+            },
+            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
+            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
+            {"$project": {"vector": 0}},
+        ]
+
+        # Execute the aggregation pipeline
+        cursor = self._data.aggregate(pipeline)
+        results = await cursor.to_list()
+
+        # Format and return the results
+        return [
+            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
+            for doc in results
+        ]
+
+
+def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
+    """Check if the collection exists. if not, create it."""
+    client = MongoClient(uri)
+    database = client.get_database(database_name)
+
+    collection_names = database.list_collection_names()
+
+    if collection_name not in collection_names:
+        database.create_collection(collection_name)
+        logger.info(f"Created collection: {collection_name}")
+    else:
+        logger.debug(f"Collection '{collection_name}' already exists.")
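Two notes on the new code. First, MongoGraphStorage keeps the whole adjacency inside node documents; this is the shape both branches of get_knowledge_graph and the $graphLookup stage traverse. A made-up example document:

```python
# Made-up node document illustrating the schema the graph code traverses:
# _id doubles as the node label, and "edges" is an embedded adjacency list
# whose "target"/"relation" keys feed $graphLookup and the edge loops above.
node_doc = {
    "_id": "Alan Turing",
    "description": "mathematician",  # extra keys surface as node properties
    "edges": [
        {"target": "Enigma", "relation": "worked_on"},
    ],
}
```

Second, $vectorSearch and create_search_index() are MongoDB Atlas Vector Search features, so MONGO_URI must point at a deployment that supports them; against a plain self-hosted mongod, the bare `except PyMongoError` in create_vector_index logs "vector index already exist" even when the index was never created.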
@@ -76,6 +76,7 @@ STORAGE_IMPLEMENTATIONS = {
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
+           "MongoVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
@@ -140,6 +141,7 @@ STORAGE_ENV_REQUIREMENTS = {
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
+   "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
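MongoVectorDBStorage can declare an empty requirements list because both connection settings have fallbacks. The optional overrides it reads, shown with the fallback defaults from the mongo_impl hunks above:

```python
# Optional environment overrides for all Mongo* backends; the values shown are
# the fallback defaults from the diff above (a config.ini [mongodb] section
# with "uri" / "database" keys works as well).
import os

os.environ.setdefault("MONGO_URI", "mongodb://root:root@localhost:27017/")
os.environ.setdefault("MONGO_DATABASE", "LightRAG")
```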
@@ -160,6 +162,7 @@ STORAGES = {
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
+   "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional


 class GPTKeywordExtractionFormat(BaseModel):
@@ -15,7 +15,7 @@ class KnowledgeGraphNode(BaseModel):

 class KnowledgeGraphEdge(BaseModel):
     id: str
-    type: str
+    type: Optional[str]
     source: str  # id of source node
     target: str  # id of target node
     properties: Dict[str, Any]  # anything else goes here
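Loosening `type` matters because Mongo-derived edges may carry an empty or missing relation. A quick sketch of what the model now accepts (values are made up; the model is copied from the hunk above):

```python
# With type: Optional[str], an edge whose relation is unknown validates;
# under the previous type: str annotation, passing None would fail validation.
from typing import Any, Dict, Optional
from pydantic import BaseModel

class KnowledgeGraphEdge(BaseModel):
    id: str
    type: Optional[str]
    source: str  # id of source node
    target: str  # id of target node
    properties: Dict[str, Any]  # anything else goes here

edge = KnowledgeGraphEdge(
    id="node-A-node-B", type=None, source="node-A", target="node-B", properties={}
)
print(edge.type)  # None
```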
@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
       <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
       <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
         <PropertyRow name={'Id'} value={edge.id} />
-        <PropertyRow name={'Type'} value={edge.type} />
+        {edge.type && <PropertyRow name={'Type'} value={edge.type} />}
         <PropertyRow
           name={'Source'}
           value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}
@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
   }

   for (const edge of graph.edges) {
-    if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
+    if (!edge.id || !edge.source || !edge.target) {
       return false
     }
   }
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
     if (source !== undefined && source !== undefined) {
       const sourceNode = rawData.nodes[source]
       const targetNode = rawData.nodes[target]
+      if (!sourceNode) {
+        console.error(`Source node ${edge.source} is undefined`)
+        continue
+      }
+      if (!targetNode) {
+        console.error(`Target node ${edge.target} is undefined`)
+        continue
+      }
       sourceNode.degree += 1
       targetNode.degree += 1
     }
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {

   for (const rawEdge of rawGraph?.edges ?? []) {
     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
-      label: rawEdge.type
+      label: rawEdge.type || undefined
     })
   }
@@ -19,7 +19,7 @@ export type RawEdgeType = {
   id: string
   source: string
   target: string
-  type: string
+  type?: string
   properties: Record<string, any>

   dynamicId: string