Merge branch 'main' into code-cleaning
@@ -1 +1,63 @@
.env
# Python-related files and directories
__pycache__
.cache

# Virtual environment directories
*.venv

# Env
env/
*.env*
.env_example

# Distribution / build files
site
dist/
build/
.eggs/
*.egg-info/
*.tgz
*.tar.gz

# Exclude files and folders
*.yml
.dockerignore
Dockerfile
Makefile

# Exclude other projects
/tests
/scripts

# Python version manager file
.python-version

# Reports
*.coverage/
*.log
log/
*.logfire

# Cache
.cache/
.mypy_cache
.pytest_cache
.ruff_cache
.gradio
.logfire
temp/

# MacOS-related files
.DS_Store

# VS Code settings (local configuration files)
.vscode

# TODO file
TODO.md

# Exclude Git-related files
.git
.github
.gitignore
.pre-commit-config.yaml
.gitignore (vendored, 79 lines changed)
@@ -1,26 +1,61 @@
__pycache__
*.egg-info
# Python-related files
__pycache__/
*.py[cod]
*.egg-info/
.eggs/
*.tgz
*.tar.gz
*.ini # Remove config.ini from repo

# Virtual Environment
.venv/
env/
venv/
*.env*
.env_example

# Build / Distribution
dist/
build/
site/

# Logs / Reports
*.log
*.logfire
*.coverage/
log/

# Caches
.cache/
.mypy_cache/
.pytest_cache/
.ruff_cache/
.gradio/
temp/

# IDE / Editor Files
.idea/
.vscode/
.vscode/settings.json

# Framework-specific files
local_neo4jWorkDir/
neo4jWorkDir/

# Data & Storage
inputs/
rag_storage/
examples/input/
examples/output/

# Miscellaneous
.DS_Store
TODO.md
ignore_this.txt
*.ignore.*

# Project-specific files
dickens/
book.txt
lightrag-dev/
.idea/
dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log
.vscode
inputs
rag_storage
.env
venv/
examples/input/
examples/output/
.DS_Store
#Remove config.ini from repo
*.ini
@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)

# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():

async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()

return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
if CHROMADB_USE_LOCAL_PERSISTENT:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"local_path": CHROMADB_LOCAL_PATH,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
},
)
)
else:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)


# Run the initialization
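The hunk above makes initialize_rag() branch on CHROMADB_USE_LOCAL_PERSISTENT between a local PersistentClient configuration and the remote HttpClient one. A minimal driver sketch, assuming it runs in the same script as the configuration above; the asyncio entry point and the printed attribute are assumptions, not part of the commit:

```python
# Hypothetical driver, not part of the commit: set the flag defined above to True
# to exercise the local PersistentClient branch instead of the remote HttpClient one.
import asyncio

async def main() -> None:
    rag = await initialize_rag()  # defined in the hunk above
    # vector_storage is the constructor kwarg passed above; printing it is just a smoke test.
    print("vector storage:", rag.vector_storage)

if __name__ == "__main__":
    asyncio.run(main())
```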
@@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorag Oracle
OracleVectorDBStorage Oracle
MongoVectorDBStorage MongoDB
```

* DOC_STATUS_STORAGE:supported implement-name
@@ -2,7 +2,7 @@ import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@@ -49,31 +49,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
**user_collection_settings,
}

auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
local_path = config.get("local_path", None)
if local_path:
self._client = PersistentClient(
path=local_path,
settings=Settings(
allow_reset=True,
anonymized_telemetry=False,
),
)
else:
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}

if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")

self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)

self._collection = self._client.get_or_create_collection(
name=self.namespace,
@@ -144,7 +154,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])

results = self._collection.query(
query_embeddings=embedding.tolist(),
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
n_results=top_k * 2,  # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)
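The query() change above guards against the embedding already being a plain Python list before calling .tolist(). The same normalization as a standalone sketch (the helper name is made up; only the conversion logic mirrors the hunk):

```python
# Sketch of the embedding normalization used before collection.query() above:
# numpy arrays are converted with .tolist(), plain lists pass through unchanged.
import numpy as np

def to_query_embeddings(embedding):
    return embedding.tolist() if not isinstance(embedding, list) else embedding

print(to_query_embeddings(np.zeros((1, 3))))   # [[0.0, 0.0, 0.0]]
print(to_query_embeddings([[0.1, 0.2, 0.3]]))  # unchanged
```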
@@ -4,6 +4,7 @@ import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
import asyncio

if not pm.is_installed("pymongo"):
pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError

from ..base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge


config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}")

self._collection_name = self.namespace

self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")

# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
return await self._data.find_one({"_id": id})

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()

async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks = []
for mode, items in data.items():
for k, v in tqdm_async(items.items(), desc="Upserting"):
for k, v in items.items():
key = f"{mode}_{k}"
result = self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
data[mode][k]["_id"] = f"{mode}_{k}"
update_tasks.append(
self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
)
if result.upserted_id:
logger.debug(f"\nInserted new document with key: {key}")
data[mode][k]["_id"] = key
await asyncio.gather(*update_tasks)
else:
for k, v in tqdm_async(data.items(), desc="Upserting"):
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
update_tasks = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)

async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = self._data.find_one({"_id": mode + "_" + id})
v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
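The upsert() rewrite above drops the blocking per-item update_one() calls and instead collects Motor coroutines and awaits them together. A minimal standalone sketch of that pattern, with made-up database and collection names; for very large batches, bulk_write() would be an alternative that uses fewer round trips:

```python
# Sketch of the concurrent upsert pattern adopted above: build Motor update_one()
# calls with upsert=True, then await them all at once with asyncio.gather().
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def bulk_upsert(uri: str, data: dict) -> None:
    coll = AsyncIOMotorClient(uri)["demo_db"]["demo_kv"]  # names are placeholders
    tasks = [
        coll.update_one({"_id": key}, {"$set": value}, upsert=True)
        for key, value in data.items()
    ]
    await asyncio.gather(*tasks)

# asyncio.run(bulk_upsert("mongodb://localhost:27017/", {"doc-1": {"content": "hello"}}))
```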
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as doc status {self.namespace}")
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)

self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)

logger.debug(f"Use MongoDB as doc status {self._collection_name}")

# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
return await self._data.find_one({"_id": id})

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()

async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
update_tasks = []
for k, v in data.items():
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)

async def drop(self) -> None:
"""Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
result = list(self._data.aggregate(pipeline))
cursor = self._data.aggregate(pipeline)
result = await cursor.to_list()
counts = {}
for doc in result:
counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
result = list(self._data.find({"status": status.value}))
cursor = self._data.find({"status": status.value})
result = await cursor.to_list()
return {
doc["_id"]: DocProcessingStatus(
content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
global_config=global_config,
embedding_func=embedding_func,
)
self.client = AsyncIOMotorClient(
os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
self.db = self.client[
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
config.get("mongodb", "database", fallback="LightRAG"),
)
]
self.collection = self.db[
os.environ.get(
"MONGO_KG_COLLECTION",
config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
)
]
)

self._collection_name = self.namespace
self.collection = database.get_collection(self._collection_name)

logger.debug(f"Use MongoDB as KG {self._collection_name}")

# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)

#
# -------------------------------------------------------------------------
|
||||
self, source_node_id: str
|
||||
) -> Union[List[Tuple[str, str]], None]:
|
||||
"""
|
||||
Return a list of (target_id, relation) for direct edges from source_node_id.
|
||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||
"""
|
||||
pipeline = [
|
||||
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
edges = result[0].get("edges", [])
|
||||
return [(e["target"], e["relation"]) for e in edges]
|
||||
return [(source_node_id, e["target"]) for e in edges]
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
1) Remove node’s doc entirely.
|
||||
1) Remove node's doc entirely.
|
||||
2) Remove inbound edges from any doc that references node_id.
|
||||
"""
|
||||
# Remove inbound edges from all other docs
|
||||
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
Placeholder for demonstration, raises NotImplementedError.
|
||||
"""
|
||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# QUERY
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all existing node _id in the database
|
||||
Returns:
|
||||
[id1, id2, ...] # Alphabetically sorted id list
|
||||
"""
|
||||
# Use MongoDB's distinct and aggregation to get all unique labels
|
||||
pipeline = [
|
||||
{"$group": {"_id": "$_id"}}, # Group by _id
|
||||
{"$sort": {"_id": 1}}, # Sort alphabetically
|
||||
]
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
labels = []
|
||||
async for doc in cursor:
|
||||
labels.append(doc["_id"])
|
||||
return labels
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
||||
Args:
|
||||
node_label: Label of the nodes to start from
|
||||
max_depth: Maximum depth of traversal (default: 5)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges of the subgraph
|
||||
"""
|
||||
label = node_label
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
try:
|
||||
if label == "*":
|
||||
# Get all nodes and edges
|
||||
async for node_doc in self.collection.find({}):
|
||||
node_id = str(node_doc["_id"])
|
||||
if node_id not in seen_nodes:
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[node_doc.get("_id")],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in node_doc.items()
|
||||
if k not in ["_id", "edges"]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Process edges
|
||||
for edge in node_doc.get("edges", []):
|
||||
edge_id = f"{node_id}-{edge['target']}"
|
||||
if edge_id not in seen_edges:
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=edge.get("relation", ""),
|
||||
source=node_id,
|
||||
target=edge["target"],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in edge.items()
|
||||
if k not in ["target", "relation"]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
else:
|
||||
# Verify if starting node exists
|
||||
start_nodes = self.collection.find({"_id": label})
|
||||
start_nodes_exist = await start_nodes.to_list(length=1)
|
||||
if not start_nodes_exist:
|
||||
logger.warning(f"Starting node with label {label} does not exist!")
|
||||
return result
|
||||
|
||||
# Use $graphLookup for traversal
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {"_id": label}
|
||||
}, # Start with nodes having the specified label
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._collection_name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"maxDepth": max_depth,
|
||||
"depthField": "depth",
|
||||
"as": "connected_nodes",
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
async for doc in self.collection.aggregate(pipeline):
|
||||
# Add the start node
|
||||
node_id = str(doc["_id"])
|
||||
if node_id not in seen_nodes:
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[
|
||||
doc.get(
|
||||
"_id",
|
||||
)
|
||||
],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in doc.items()
|
||||
if k
|
||||
not in [
|
||||
"_id",
|
||||
"edges",
|
||||
"connected_nodes",
|
||||
"depth",
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Add edges from start node
|
||||
for edge in doc.get("edges", []):
|
||||
edge_id = f"{node_id}-{edge['target']}"
|
||||
if edge_id not in seen_edges:
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=edge.get("relation", ""),
|
||||
source=node_id,
|
||||
target=edge["target"],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in edge.items()
|
||||
if k not in ["target", "relation"]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
# Add connected nodes and their edges
|
||||
for connected in doc.get("connected_nodes", []):
|
||||
node_id = str(connected["_id"])
|
||||
if node_id not in seen_nodes:
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[connected.get("_id")],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in connected.items()
|
||||
if k not in ["_id", "edges", "depth"]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Add edges from connected nodes
|
||||
for edge in connected.get("edges", []):
|
||||
edge_id = f"{node_id}-{edge['target']}"
|
||||
if edge_id not in seen_edges:
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=edge.get("relation", ""),
|
||||
source=node_id,
|
||||
target=edge["target"],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in edge.items()
|
||||
if k not in ["target", "relation"]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"MongoDB query failed: {str(e)}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class MongoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
|
||||
# Ensure vector index exists
|
||||
self.create_vector_index(uri, database.name, self._collection_name)
|
||||
|
||||
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
|
||||
"""Creates an Atlas Vector Search index."""
|
||||
client = MongoClient(uri)
|
||||
collection = client.get_database(database_name).get_collection(
|
||||
self._collection_name
|
||||
)
|
||||
|
||||
try:
|
||||
search_index_model = SearchIndexModel(
|
||||
definition={
|
||||
"fields": [
|
||||
{
|
||||
"type": "vector",
|
||||
"numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions
|
||||
"path": "vector",
|
||||
"similarity": "cosine", # Options: euclidean, cosine, dotProduct
|
||||
}
|
||||
]
|
||||
},
|
||||
name="vector_knn_index",
|
||||
type="vectorSearch",
|
||||
)
|
||||
|
||||
collection.create_search_index(search_index_model)
|
||||
logger.info("Vector index created successfully.")
|
||||
|
||||
except PyMongoError as _:
|
||||
logger.debug("vector index already exist")
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not data:
|
||||
logger.warning("You are inserting an empty data set to vector DB")
|
||||
return []
|
||||
|
||||
list_data = [
|
||||
{
|
||||
"_id": k,
|
||||
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(
|
||||
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
||||
)
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
|
||||
|
||||
update_tasks = []
|
||||
for doc in list_data:
|
||||
update_tasks.append(
|
||||
self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
|
||||
)
|
||||
await asyncio.gather(*update_tasks)
|
||||
|
||||
return list_data
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
"""Queries the vector database using Atlas Vector Search."""
|
||||
# Generate the embedding
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
# Convert numpy array to a list to ensure compatibility with MongoDB
|
||||
query_vector = embedding[0].tolist()
|
||||
|
||||
# Define the aggregation pipeline with the converted query vector
|
||||
pipeline = [
|
||||
{
|
||||
"$vectorSearch": {
|
||||
"index": "vector_knn_index", # Ensure this matches the created index name
|
||||
"path": "vector",
|
||||
"queryVector": query_vector,
|
||||
"numCandidates": 100, # Adjust for performance
|
||||
"limit": top_k,
|
||||
}
|
||||
},
|
||||
{"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
|
||||
{"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
|
||||
{"$project": {"vector": 0}},
|
||||
]
|
||||
|
||||
# Execute the aggregation pipeline
|
||||
cursor = self._data.aggregate(pipeline)
|
||||
results = await cursor.to_list()
|
||||
|
||||
# Format and return the results
|
||||
return [
|
||||
{**doc, "id": doc["_id"], "distance": doc.get("score", None)}
|
||||
for doc in results
|
||||
]
|
||||
|
||||
|
||||
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
||||
"""Check if the collection exists. if not, create it."""
|
||||
client = MongoClient(uri)
|
||||
database = client.get_database(database_name)
|
||||
|
||||
collection_names = database.list_collection_names()
|
||||
|
||||
if collection_name not in collection_names:
|
||||
database.create_collection(collection_name)
|
||||
logger.info(f"Created collection: {collection_name}")
|
||||
else:
|
||||
logger.debug(f"Collection '{collection_name}' already exists.")
|
||||
|
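MongoVectorDBStorage.query() above builds an Atlas Vector Search aggregation: a $vectorSearch stage, a score projected from vectorSearchScore, a threshold filter, and removal of the stored vectors. The same pipeline shape as a small sketch; it only works against MongoDB Atlas with the "vector_knn_index" search index created as above, and the function name is made up:

```python
# Sketch of the aggregation pipeline shape used by MongoVectorDBStorage.query() above.
def build_vector_search_pipeline(query_vector, top_k, threshold):
    return [
        {
            "$vectorSearch": {
                "index": "vector_knn_index",   # must match the index created in create_vector_index()
                "path": "vector",
                "queryVector": query_vector,
                "numCandidates": 100,          # candidate pool; tune for recall vs. latency
                "limit": top_k,
            }
        },
        {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
        {"$match": {"score": {"$gte": threshold}}},  # cosine_better_than_threshold in the class
        {"$project": {"vector": 0}},                 # drop raw vectors from the results
    ]
```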
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
async def index_done_callback(self):
print("KG successfully indexed.")

async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
async def _label_exists(self, label: str) -> bool:
"""Check if a label exists in the Neo4j database."""
query = "CALL db.labels() YIELD label RETURN label"
try:
async with self._driver.session(database=self._DATABASE) as session:
result = await session.run(query)
labels = [record["label"] for record in await result.data()]
return label in labels
except Exception as e:
logger.error(f"Error checking label existence: {e}")
return False

async def _ensure_label(self, label: str) -> str:
"""Ensure a label exists by validating it."""
clean_label = label.strip('"')
if not await self._label_exists(clean_label):
logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
return clean_label

async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]

async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get node by its label identifier.

Args:
node_id: The node label to look up

Returns:
dict: Node properties if found
None: If node not found
"""
async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"')
entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
"""Find edge between two nodes identified by their labels.

Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
source_node_id (str): Label of the source node
target_node_id (str): Label of the target node

Returns:
list: List of all relationships/edges found
dict: Edge properties if found, with at least {"weight": 0.0}
None: If error occurs
"""
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
try:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')

result = await session.run(query)
record = await result.single()
if record:
result = dict(record["edge_properties"])
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
return result
else:
return None

result = await session.run(query)
record = await result.single()
if record and "edge_properties" in record:
try:
result = dict(record["edge_properties"])
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"target_id": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)

logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}

logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {"weight": 0.0, "source_id": None, "target_id": None}

except Exception as e:
logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}

async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
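The reworked get_edge() above guarantees that callers always see at least "weight", "source_id", and "target_id", whether the edge exists, lacks fields, or an error occurs. The defaulting step in isolation, as a sketch with a made-up helper name:

```python
# Sketch of the defaulting applied to properties(r) in get_edge() above, so that
# downstream merge code can rely on these keys being present.
def with_edge_defaults(edge_properties):
    defaults = {"weight": 0.0, "source_id": None, "target_id": None}
    result = dict(edge_properties or {})
    for key, default_value in defaults.items():
        result.setdefault(key, default_value)
    return result

print(with_edge_defaults({"weight": 2.0}))  # {'weight': 2.0, 'source_id': None, 'target_id': None}
print(with_edge_defaults(None))             # all defaults
```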
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
label = await self._ensure_label(node_id)
properties = node_data

async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
source_label = await self._ensure_label(source_node_id)
target_label = await self._ensure_label(target_node_id)
edge_properties = edge_data

async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f"""
MATCH (source:`{source_node_label}`)
MATCH (source:`{source_label}`)
WITH source
MATCH (target:`{target_node_label}`)
MATCH (target:`{target_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += $properties
RETURN r
"""
await tx.run(query, properties=edge_properties)
result = await tx.run(query, properties=edge_properties)
record = await result.single()
logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
)

try:
@@ -78,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
@@ -142,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -162,6 +164,7 @@ STORAGES = {
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
@@ -239,25 +239,65 @@ async def _merge_edges_then_upsert(

if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_edge["description"])
already_keywords.extend(
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
)
# Handle the case where get_edge returns None or missing fields
if already_edge:
# Get weight with default 0.0 if missing
if "weight" in already_edge:
already_weights.append(already_edge["weight"])
else:
logger.warning(
f"Edge between {src_id} and {tgt_id} missing weight field"
)
already_weights.append(0.0)

# Get source_id with empty string default if missing or None
if "source_id" in already_edge and already_edge["source_id"] is not None:
already_source_ids.extend(
split_string_by_multi_markers(
already_edge["source_id"], [GRAPH_FIELD_SEP]
)
)

# Get description with empty string default if missing or None
if (
"description" in already_edge
and already_edge["description"] is not None
):
already_description.append(already_edge["description"])

# Get keywords with empty string default if missing or None
if "keywords" in already_edge and already_edge["keywords"] is not None:
already_keywords.extend(
split_string_by_multi_markers(
already_edge["keywords"], [GRAPH_FIELD_SEP]
)
)

# Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
sorted(
set(
[dp["description"] for dp in edges_data if dp.get("description")]
+ already_description
)
)
)
keywords = GRAPH_FIELD_SEP.join(
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
sorted(
set(
[dp["keywords"] for dp in edges_data if dp.get("keywords")]
+ already_keywords
)
)
)
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
set(
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
+ already_source_ids
)
)

for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
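The _merge_edges_then_upsert() changes above filter out missing descriptions, keywords, and source_ids before deduplicating and joining them on GRAPH_FIELD_SEP. The merge rule in isolation, as a sketch; the separator value here is a placeholder standing in for lightrag's constant:

```python
# Sketch of the hardened merge rule above: drop falsy values, deduplicate, sort, join.
GRAPH_FIELD_SEP = "<SEP>"  # placeholder value

def merge_field(new_values, existing):
    return GRAPH_FIELD_SEP.join(sorted(set([v for v in new_values if v] + list(existing))))

print(merge_field(["roads", None, "bridges"], ["roads"]))  # bridges<SEP>roads
```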
@@ -1,7 +1,7 @@
from __future__ import annotations

from pydantic import BaseModel
from typing import Any
from typing import List, Dict, Any, Optional


class GPTKeywordExtractionFormat(BaseModel):
@@ -17,7 +17,7 @@ class KnowledgeGraphNode(BaseModel):

class KnowledgeGraphEdge(BaseModel):
id: str
type: str
type: Optional[str]
source: str  # id of source node
target: str  # id of target node
properties: dict[str, Any]  # anything else goes here
@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
<label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
<div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
<PropertyRow name={'Id'} value={edge.id} />
<PropertyRow name={'Type'} value={edge.type} />
{edge.type && <PropertyRow name={'Type'} value={edge.type} />}
<PropertyRow
name={'Source'}
value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}

@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
}

for (const edge of graph.edges) {
if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
if (!edge.id || !edge.source || !edge.target) {
return false
}
}
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
if (source !== undefined && source !== undefined) {
const sourceNode = rawData.nodes[source]
const targetNode = rawData.nodes[target]
if (!sourceNode) {
console.error(`Source node ${edge.source} is undefined`)
continue
}
if (!targetNode) {
console.error(`Target node ${edge.target} is undefined`)
continue
}
sourceNode.degree += 1
targetNode.degree += 1
}
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {

for (const rawEdge of rawGraph?.edges ?? []) {
rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
label: rawEdge.type
label: rawEdge.type || undefined
})
}

@@ -19,7 +19,7 @@ export type RawEdgeType = {
id: string
source: string
target: string
type: string
type?: string
properties: Record<string, any>

dynamicId: string