Merge branch 'main' into code-cleaning

zrguo committed 2025-02-16 19:41:05 +08:00 (committed by GitHub)
13 changed files with 822 additions and 183 deletions


@@ -1 +1,63 @@
.env
# Python-related files and directories
__pycache__
.cache
# Virtual environment directories
*.venv
# Env
env/
*.env*
.env_example
# Distribution / build files
site
dist/
build/
.eggs/
*.egg-info/
*.tgz
*.tar.gz
# Exclude files and folders
*.yml
.dockerignore
Dockerfile
Makefile
# Exclude other projects
/tests
/scripts
# Python version manager file
.python-version
# Reports
*.coverage/
*.log
log/
*.logfire
# Cache
.cache/
.mypy_cache
.pytest_cache
.ruff_cache
.gradio
.logfire
temp/
# macOS-related files
.DS_Store
# VS Code settings (local configuration files)
.vscode
# Project TODO file
TODO.md
# Exclude Git-related files
.git
.github
.gitignore
.pre-commit-config.yaml

.gitignore

@@ -1,26 +1,61 @@
__pycache__
*.egg-info
# Python-related files
__pycache__/
*.py[cod]
*.egg-info/
.eggs/
*.tgz
*.tar.gz
# Remove config.ini from repo
*.ini
# Virtual Environment
.venv/
env/
venv/
*.env*
.env_example
# Build / Distribution
dist/
build/
site/
# Logs / Reports
*.log
*.logfire
*.coverage/
log/
# Caches
.cache/
.mypy_cache/
.pytest_cache/
.ruff_cache/
.gradio/
temp/
# IDE / Editor Files
.idea/
.vscode/
.vscode/settings.json
# Framework-specific files
local_neo4jWorkDir/
neo4jWorkDir/
# Data & Storage
inputs/
rag_storage/
examples/input/
examples/output/
# Miscellaneous
.DS_Store
TODO.md
ignore_this.txt
*.ignore.*
# Project-specific files
dickens/
book.txt
lightrag-dev/
.idea/
dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log
.vscode
inputs
rag_storage
.env
venv/
examples/input/
examples/output/
.DS_Store
#Remove config.ini from repo
*.ini


@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
async def initialize_rag():
    embedding_func_instance = await create_embedding_function_instance()
    if CHROMADB_USE_LOCAL_PERSISTENT:
        return LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=gpt_4o_mini_complete,
            embedding_func=embedding_func_instance,
            vector_storage="ChromaVectorDBStorage",
            log_level="DEBUG",
            embedding_batch_num=32,
            vector_db_storage_cls_kwargs={
                "local_path": CHROMADB_LOCAL_PATH,
                "collection_settings": {
                    "hnsw:space": "cosine",
                    "hnsw:construction_ef": 128,
                    "hnsw:search_ef": 128,
                    "hnsw:M": 16,
                    "hnsw:batch_size": 100,
                    "hnsw:sync_threshold": 1000,
                },
            },
        )
    else:
        return LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=gpt_4o_mini_complete,
            embedding_func=embedding_func_instance,
            vector_storage="ChromaVectorDBStorage",
            log_level="DEBUG",
            embedding_batch_num=32,
            vector_db_storage_cls_kwargs={
                "host": CHROMADB_HOST,
                "port": CHROMADB_PORT,
                "auth_token": CHROMADB_AUTH_TOKEN,
                "auth_provider": CHROMADB_AUTH_PROVIDER,
                "auth_header_name": CHROMADB_AUTH_HEADER,
                "collection_settings": {
                    "hnsw:space": "cosine",
                    "hnsw:construction_ef": 128,
                    "hnsw:search_ef": 128,
                    "hnsw:M": 16,
                    "hnsw:batch_size": 100,
                    "hnsw:sync_threshold": 1000,
                },
            },
        )
# Run the initialization
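For context, a minimal sketch of driving this example, assuming LightRAG's async insert/query API (`ainsert`/`aquery`); the flag above selects the local `PersistentClient` or the remote `HttpClient`:

```python
import asyncio

async def main():
    rag = await initialize_rag()  # client mode chosen by CHROMADB_USE_LOCAL_PERSISTENT
    await rag.ainsert("LightRAG stores chunk embeddings in ChromaDB.")
    print(await rag.aquery("Where are embeddings stored?"))

asyncio.run(main())
```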


@@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorag Oracle
OracleVectorDBStorage Oracle
MongoVectorDBStorage MongoDB
```
* DOC_STATUS_STORAGE: supported implement-name


@@ -2,7 +2,7 @@ import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@@ -49,31 +49,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
            **user_collection_settings,
        }

        local_path = config.get("local_path", None)
        if local_path:
            self._client = PersistentClient(
                path=local_path,
                settings=Settings(
                    allow_reset=True,
                    anonymized_telemetry=False,
                ),
            )
        else:
            auth_provider = config.get(
                "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
            )
            auth_credentials = config.get("auth_token", "secret-token")
            headers = {}

            if "token_authn" in auth_provider:
                headers = {
                    config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
                }
            elif "basic_authn" in auth_provider:
                auth_credentials = config.get("auth_credentials", "admin:admin")

            self._client = HttpClient(
                host=config.get("host", "localhost"),
                port=config.get("port", 8000),
                headers=headers,
                settings=Settings(
                    chroma_api_impl="rest",
                    chroma_client_auth_provider=auth_provider,
                    chroma_client_auth_credentials=auth_credentials,
                    allow_reset=True,
                    anonymized_telemetry=False,
                ),
            )
self._collection = self._client.get_or_create_collection(
name=self.namespace,
@@ -144,7 +154,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist(),
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)
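The query path now also normalizes the embedding type and over-fetches (`top_k * 2`) so that hits below the cosine threshold can be dropped before truncating back to `top_k`. A small illustrative sketch of that pattern (the helper and threshold names are hypothetical, not part of the diff):

```python
def filter_candidates(ids, distances, top_k, cosine_better_than_threshold):
    # Chroma returns cosine *distance*; similarity = 1 - distance.
    kept = [
        (doc_id, 1 - dist)
        for doc_id, dist in zip(ids, distances)
        if 1 - dist >= cosine_better_than_threshold
    ]
    return kept[:top_k]  # truncate back to the requested top_k
```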


@@ -4,6 +4,7 @@ import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
import asyncio
if not pm.is_installed("pymongo"):
pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from ..base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
@dataclass
class MongoKVStorage(BaseKVStorage):
    def __post_init__(self):
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as KV {self._collection_name}")
        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            update_tasks = []
            for mode, items in data.items():
                for k, v in items.items():
                    key = f"{mode}_{k}"
                    data[mode][k]["_id"] = key
                    update_tasks.append(
                        self._data.update_one(
                            {"_id": key}, {"$setOnInsert": v}, upsert=True
                        )
                    )
            await asyncio.gather(*update_tasks)
        else:
            update_tasks = []
            for k, v in data.items():
                data[k]["_id"] = k
                update_tasks.append(
                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
                )
            await asyncio.gather(*update_tasks)
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = self._data.find_one({"_id": mode + "_" + id})
v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
    def __post_init__(self):
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        update_tasks = []
        for k, v in data.items():
            data[k]["_id"] = k
            update_tasks.append(
                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
            )
        await asyncio.gather(*update_tasks)
async def drop(self) -> None:
"""Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
result = list(self._data.aggregate(pipeline))
cursor = self._data.aggregate(pipeline)
result = await cursor.to_list()
counts = {}
for doc in result:
counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
result = list(self._data.find({"status": status.value}))
cursor = self._data.find({"status": status.value})
result = await cursor.to_list()
return {
doc["_id"]: DocProcessingStatus(
content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
global_config=global_config,
embedding_func=embedding_func,
)
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )
        self._collection_name = self.namespace
        self.collection = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as KG {self._collection_name}")
        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)
#
# -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
"""
Return a list of (target_id, relation) for direct edges from source_node_id.
Return a list of (source_id, target_id) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
"""
pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
return None
edges = result[0].get("edges", [])
return [(e["target"], e["relation"]) for e in edges]
return [(source_node_id, e["target"]) for e in edges]
#
# -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str):
"""
1) Remove nodes doc entirely.
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
"""
# Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
Placeholder for demonstration, raises NotImplementedError.
"""
raise NotImplementedError("Node embedding is not used in lightrag.")
#
# -------------------------------------------------------------------------
# QUERY
# -------------------------------------------------------------------------
#
async def get_all_labels(self) -> list[str]:
"""
Get all existing node _id in the database
Returns:
[id1, id2, ...] # Alphabetically sorted id list
"""
# Use a MongoDB aggregation to collect all unique node ids
pipeline = [
{"$group": {"_id": "$_id"}}, # Group by _id
{"$sort": {"_id": 1}}, # Sort alphabetically
]
cursor = self.collection.aggregate(pipeline)
labels = []
async for doc in cursor:
labels.append(doc["_id"])
return labels
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
"""
label = node_label
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
try:
if label == "*":
# Get all nodes and edges
async for node_doc in self.collection.find({}):
node_id = str(node_doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_doc.get("_id")],
properties={
k: v
for k, v in node_doc.items()
if k not in ["_id", "edges"]
},
)
)
seen_nodes.add(node_id)
# Process edges
for edge in node_doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
else:
# Verify if starting node exists
start_nodes = self.collection.find({"_id": label})
start_nodes_exist = await start_nodes.to_list(length=1)
if not start_nodes_exist:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# Use $graphLookup for traversal
pipeline = [
{
"$match": {"_id": label}
}, # Start with nodes having the specified label
{
"$graphLookup": {
"from": self._collection_name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_nodes",
}
},
]
async for doc in self.collection.aggregate(pipeline):
# Add the start node
node_id = str(doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[doc.get("_id")],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"edges",
"connected_nodes",
"depth",
]
},
)
)
seen_nodes.add(node_id)
# Add edges from start node
for edge in doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
# Add connected nodes and their edges
for connected in doc.get("connected_nodes", []):
node_id = str(connected["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[connected.get("_id")],
properties={
k: v
for k, v in connected.items()
if k not in ["_id", "edges", "depth"]
},
)
)
seen_nodes.add(node_id)
# Add edges from connected nodes
for edge in connected.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except PyMongoError as e:
logger.error(f"MongoDB query failed: {str(e)}")
return result
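A hypothetical call against the method above; "*" dumps the whole graph, while a concrete label walks outward via $graphLookup up to max_depth:

```python
# graph_storage: an initialized MongoGraphStorage (hypothetical variable)
kg = await graph_storage.get_knowledge_graph("Alice", max_depth=2)
print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")
```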
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
# Ensure vector index exists
self.create_vector_index(uri, database.name, self._collection_name)
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
"""Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try:
search_index_model = SearchIndexModel(
definition={
"fields": [
{
"type": "vector",
"numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions
"path": "vector",
"similarity": "cosine", # Options: euclidean, cosine, dotProduct
}
]
},
name="vector_knn_index",
type="vectorSearch",
)
collection.create_search_index(search_index_model)
logger.info("Vector index created successfully.")
except PyMongoError as _:
logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict]):
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
if not data:
logger.warning("You are inserting an empty data set to vector DB")
return []
list_data = [
{
"_id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result
embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
update_tasks = []
for doc in list_data:
update_tasks.append(
self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
)
await asyncio.gather(*update_tasks)
return list_data
async def query(self, query, top_k=5):
"""Queries the vector database using Atlas Vector Search."""
# Generate the embedding
embedding = await self.embedding_func([query])
# Convert numpy array to a list to ensure compatibility with MongoDB
query_vector = embedding[0].tolist()
# Define the aggregation pipeline with the converted query vector
pipeline = [
{
"$vectorSearch": {
"index": "vector_knn_index", # Ensure this matches the created index name
"path": "vector",
"queryVector": query_vector,
"numCandidates": 100, # Adjust for performance
"limit": top_k,
}
},
{"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
{"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
{"$project": {"vector": 0}},
]
# Execute the aggregation pipeline
cursor = self._data.aggregate(pipeline)
results = await cursor.to_list()
# Format and return the results
return [
{**doc, "id": doc["_id"], "distance": doc.get("score", None)}
for doc in results
]
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""
client = MongoClient(uri)
database = client.get_database(database_name)
collection_names = database.list_collection_names()
if collection_name not in collection_names:
database.create_collection(collection_name)
logger.info(f"Created collection: {collection_name}")
else:
logger.debug(f"Collection '{collection_name}' already exists.")


@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
async def index_done_callback(self):
print("KG successfully indexed.")
    async def _label_exists(self, label: str) -> bool:
        """Check if a label exists in the Neo4j database."""
        query = "CALL db.labels() YIELD label RETURN label"
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                result = await session.run(query)
                labels = [record["label"] for record in await result.data()]
                return label in labels
        except Exception as e:
            logger.error(f"Error checking label existence: {e}")
            return False

    async def _ensure_label(self, label: str) -> str:
        """Ensure a label exists by validating it."""
        clean_label = label.strip('"')
        if not await self._label_exists(clean_label):
            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
        return clean_label

    async def has_node(self, node_id: str) -> bool:
        entity_name_label = await self._ensure_label(node_id)
        async with self._driver.session(database=self._DATABASE) as session:
            query = (
                f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get node by its label identifier.
Args:
node_id: The node label to look up
Returns:
dict: Node properties if found
None: If node not found
"""
async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"')
entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Find edge between two nodes identified by their labels.

        Args:
            source_node_id (str): Label of the source node
            target_node_id (str): Label of the target node

        Returns:
            dict: Edge properties if found, with at least {"weight": 0.0}
            None: If error occurs
        """
        try:
            entity_name_label_source = source_node_id.strip('"')
            entity_name_label_target = target_node_id.strip('"')

            async with self._driver.session(database=self._DATABASE) as session:
                query = f"""
                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
                RETURN properties(r) as edge_properties
                LIMIT 1
                """
                result = await session.run(query)
                record = await result.single()
                if record and "edge_properties" in record:
                    try:
                        result = dict(record["edge_properties"])
                        # Ensure required keys exist with defaults
                        required_keys = {
                            "weight": 0.0,
                            "source_id": None,
                            "target_id": None,
                        }
                        for key, default_value in required_keys.items():
                            if key not in result:
                                result[key] = default_value
                                logger.warning(
                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
                                    f"missing {key}, using default: {default_value}"
                                )
                        logger.debug(
                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
                        )
                        return result
                    except (KeyError, TypeError, ValueError) as e:
                        logger.error(
                            f"Error processing edge properties between {entity_name_label_source} "
                            f"and {entity_name_label_target}: {str(e)}"
                        )
                        # Return default edge properties on error
                        return {"weight": 0.0, "source_id": None, "target_id": None}

                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
                )
                # Return default edge properties when no edge found
                return {"weight": 0.0, "source_id": None, "target_id": None}
        except Exception as e:
            logger.error(
                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
            )
            # Return default edge properties on error
            return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
label = await self._ensure_label(node_id)
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
        source_label = await self._ensure_label(source_node_id)
        target_label = await self._ensure_label(target_node_id)
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            query = f"""
            MATCH (source:`{source_label}`)
            WITH source
            MATCH (target:`{target_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += $properties
            RETURN r
            """
            result = await tx.run(query, properties=edge_properties)
            record = await result.single()
            logger.debug(
                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
            )
try:
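The refactored get_edge above guarantees a dict with default keys instead of raising. A caller-side sketch of what that contract means (node labels are hypothetical):

```python
# get_edge now returns defaults rather than raising, so merge code can
# read fields without KeyError guards.
edge = await graph_storage.get_edge("PERSON_A", "PERSON_B")
print(edge["weight"])  # 0.0 when the edge or its weight field is missing
```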


@@ -78,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
@@ -142,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -162,6 +164,7 @@ STORAGES = {
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",


@@ -239,25 +239,65 @@ async def _merge_edges_then_upsert(
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # Handle the case where get_edge returns None or missing fields
        if already_edge:
            # Get weight with default 0.0 if missing
            if "weight" in already_edge:
                already_weights.append(already_edge["weight"])
            else:
                logger.warning(
                    f"Edge between {src_id} and {tgt_id} missing weight field"
                )
                already_weights.append(0.0)

            # Collect source_id only when present and not None
            if "source_id" in already_edge and already_edge["source_id"] is not None:
                already_source_ids.extend(
                    split_string_by_multi_markers(
                        already_edge["source_id"], [GRAPH_FIELD_SEP]
                    )
                )

            # Collect description only when present and not None
            if (
                "description" in already_edge
                and already_edge["description"] is not None
            ):
                already_description.append(already_edge["description"])

            # Collect keywords only when present and not None
            if "keywords" in already_edge and already_edge["keywords"] is not None:
                already_keywords.extend(
                    split_string_by_multi_markers(
                        already_edge["keywords"], [GRAPH_FIELD_SEP]
                    )
                )

    # Process edges_data with None checks
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(
            set(
                [dp["description"] for dp in edges_data if dp.get("description")]
                + already_description
            )
        )
    )
    keywords = GRAPH_FIELD_SEP.join(
        sorted(
            set(
                [dp["keywords"] for dp in edges_data if dp.get("keywords")]
                + already_keywords
            )
        )
    )
    source_id = GRAPH_FIELD_SEP.join(
        set(
            [dp["source_id"] for dp in edges_data if dp.get("source_id")]
            + already_source_ids
        )
    )
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
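A worked micro-example of the merge semantics above; the GRAPH_FIELD_SEP value is assumed for this sketch:

```python
GRAPH_FIELD_SEP = "<SEP>"  # assumed separator value for this sketch

already_keywords = ["alpha", "beta"]  # recovered from the existing edge
edges_data = [{"keywords": "beta"}, {"keywords": None}]  # None is now skipped

merged = GRAPH_FIELD_SEP.join(
    sorted(
        set([dp["keywords"] for dp in edges_data if dp.get("keywords")] + already_keywords)
    )
)
print(merged)  # alpha<SEP>beta
```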


@@ -1,7 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel
from typing import Any
from typing import List, Dict, Any, Optional
class GPTKeywordExtractionFormat(BaseModel):
@@ -17,7 +17,7 @@ class KnowledgeGraphNode(BaseModel):
class KnowledgeGraphEdge(BaseModel):
id: str
type: str
type: Optional[str]
source: str # id of source node
target: str # id of target node
properties: dict[str, Any] # anything else goes here
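Since `type` is now Optional, an edge without a relation label validates; a quick sketch:

```python
edge = KnowledgeGraphEdge(
    id="e1", type=None, source="n1", target="n2", properties={}
)
print(edge.type)  # None -- the web UI below only renders Type when present
```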


@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
<label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
<div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
<PropertyRow name={'Id'} value={edge.id} />
<PropertyRow name={'Type'} value={edge.type} />
{edge.type && <PropertyRow name={'Type'} value={edge.type} />}
<PropertyRow
name={'Source'}
value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}


@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
}
for (const edge of graph.edges) {
if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
if (!edge.id || !edge.source || !edge.target) {
return false
}
}
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
if (source !== undefined && target !== undefined) {
const sourceNode = rawData.nodes[source]
const targetNode = rawData.nodes[target]
if (!sourceNode) {
console.error(`Source node ${edge.source} is undefined`)
continue
}
if (!targetNode) {
console.error(`Target node ${edge.target} is undefined`)
continue
}
sourceNode.degree += 1
targetNode.degree += 1
}
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
for (const rawEdge of rawGraph?.edges ?? []) {
rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
label: rawEdge.type
label: rawEdge.type || undefined
})
}


@@ -19,7 +19,7 @@ export type RawEdgeType = {
id: string
source: string
target: string
type: string
type?: string
properties: Record<string, any>
dynamicId: string