Merge pull request #795 from YanSte/make-clear-what-implemented-or-not
Enhancing ABC Enforcement and Standardizing Subclass Implementations
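The diff below converts LightRAG's storage base classes into real abstract base classes: StorageNameSpace, BaseVectorStorage, BaseKVStorage, BaseGraphStorage, and DocStatusStorage now inherit from abc.ABC and mark required methods with @abstractmethod, while each concrete backend is stamped with @typing.final. A minimal sketch of the pattern being enforced (class and method names here are illustrative, not taken from the diff):

    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from typing import final


    @dataclass
    class ExampleStorage(ABC):
        namespace: str

        @abstractmethod
        async def upsert(self, data: dict[str, dict]) -> None:
            """Every backend must provide a concrete upsert."""


    @final  # type checkers will now reject further subclassing
    @dataclass
    class InMemoryStorage(ExampleStorage):
        async def upsert(self, data: dict[str, dict]) -> None:
            ...  # a real backend would persist `data` here

ExampleStorage(namespace="kv") raises TypeError because the class still has an abstract method; InMemoryStorage(namespace="kv") instantiates normally. Note that the ABC check is enforced at runtime, whereas @final is only checked by static type checkers.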
lightrag/base.py (132 changed lines)
@@ -1,9 +1,10 @@
 from __future__ import annotations

+from abc import ABC, abstractmethod
+from enum import StrEnum
 import os
 from dotenv import load_dotenv
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import (
     Any,
     Literal,
@@ -82,138 +83,130 @@ class QueryParam:


 @dataclass
-class StorageNameSpace:
+class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]

+    @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
-        pass


 @dataclass
-class BaseVectorStorage(StorageNameSpace):
+class BaseVectorStorage(StorageNameSpace, ABC):
     embedding_func: EmbeddingFunc
     cosine_better_than_threshold: float = field(default=0.2)
     meta_fields: set[str] = field(default_factory=set)

+    @abstractmethod
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Query the vector storage and retrieve top_k results."""

+    @abstractmethod
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        """Use 'content' field from value for embedding, use key as id.
-        If embedding_func is None, use 'embedding' field from value
-        """
-        raise NotImplementedError
+        """Insert or update vectors in the storage."""

+    @abstractmethod
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete a single entity by its name"""
-        raise NotImplementedError
+        """Delete a single entity by its name."""

+    @abstractmethod
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete relations for a given entity by scanning metadata"""
-        raise NotImplementedError
+        """Delete relations for a given entity."""


 @dataclass
-class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseKVStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc

+    @abstractmethod
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        raise NotImplementedError
+        """Get value by id"""

+    @abstractmethod
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Get values by ids"""

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    @abstractmethod
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return un-exist keys"""
-        raise NotImplementedError

-    async def upsert(self, data: dict[str, Any]) -> None:
-        raise NotImplementedError
+    @abstractmethod
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """Upsert data"""

+    @abstractmethod
     async def drop(self) -> None:
-        raise NotImplementedError
+        """Drop the storage"""


 @dataclass
-class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseGraphStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc
+    """Check if a node exists in the graph."""

+    @abstractmethod
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError

+    """Check if an edge exists in the graph."""
+
+    @abstractmethod
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError

+    """Get the degree of a node."""
+
+    @abstractmethod
     async def node_degree(self, node_id: str) -> int:
         raise NotImplementedError

+    """Get the degree of an edge."""
+
+    @abstractmethod
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError

+    """Get a node by its id."""
+
+    @abstractmethod
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

+    """Get an edge by its source and target node ids."""
+
+    @abstractmethod
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
         raise NotImplementedError

+    """Get all edges connected to a node."""
+
+    @abstractmethod
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         raise NotImplementedError

+    """Upsert a node into the graph."""
+
+    @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         raise NotImplementedError

+    """Upsert an edge into the graph."""
+
+    @abstractmethod
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
         raise NotImplementedError

+    """Delete a node from the graph."""
+
+    @abstractmethod
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

+    """Embed nodes using an algorithm."""
+
+    @abstractmethod
     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")

+    """Get all labels in the graph."""
+
+    @abstractmethod
     async def get_all_labels(self) -> list[str]:
         raise NotImplementedError

+    """Get a knowledge graph of a node."""
+
+    @abstractmethod
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """Retrieve a subgraph of the knowledge graph starting from a given node."""


-class DocStatus(str, Enum):
-    """Document processing status enum"""
+class DocStatus(StrEnum):
+    """Document processing status"""

     PENDING = "pending"
     PROCESSING = "processing"
@@ -245,19 +238,16 @@ class DocProcessingStatus:
     """Additional metadata"""


-class DocStatusStorage(BaseKVStorage):
+@dataclass
+class DocStatusStorage(BaseKVStorage, ABC):
     """Base class for document status storage"""

+    @abstractmethod
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        raise NotImplementedError

+    @abstractmethod
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
-        raise NotImplementedError
-
-    async def update_doc_status(self, data: dict[str, Any]) -> None:
-        """Updates the status of a document. By default, it calls upsert."""
-        await self.upsert(data)
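The practical effect of these base-class changes: a backend that forgets to implement a required method used to fail only when that method was first called (via raise NotImplementedError); with @abstractmethod it fails at construction time. A small demonstration under the same pattern (names invented for illustration):

    from abc import ABC, abstractmethod


    class GraphBase(ABC):
        @abstractmethod
        async def has_node(self, node_id: str) -> bool: ...


    class IncompleteBackend(GraphBase):
        pass  # has_node deliberately missing


    try:
        IncompleteBackend()  # fails here, not on the first has_node() call
    except TypeError as e:
        print(e)  # Can't instantiate abstract class IncompleteBackend ...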
@@ -5,19 +5,11 @@ import os
 import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union, final
 import numpy as np
 import pipmaster as pm
+from lightrag.types import KnowledgeGraph

-if not pm.is_installed("psycopg-pool"):
-    pm.install("psycopg-pool")
-    pm.install("psycopg[binary,pool]")
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
-
-
-import psycopg
-from psycopg.rows import namedtuple_row
-from psycopg_pool import AsyncConnectionPool, PoolTimeout
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,6 +27,23 @@ if sys.platform.startswith("win"):
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


+if not pm.is_installed("psycopg-pool"):
+    pm.install("psycopg-pool")
+    pm.install("psycopg[binary,pool]")
+
+if not pm.is_installed("asyncpg"):
+    pm.install("asyncpg")
+
+try:
+    import psycopg
+    from psycopg.rows import namedtuple_row
+    from psycopg_pool import AsyncConnectionPool, PoolTimeout
+except ImportError:
+    raise ImportError(
+        "`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
+    )
+
+
 class AGEQueryException(Exception):
     """Exception for the AGE queries."""

@@ -53,6 +62,7 @@ class AGEQueryException(Exception):
         return self.details


+@final
 @dataclass
 class AGEStorage(BaseGraphStorage):
     @staticmethod
@@ -98,9 +108,6 @@ class AGEStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
-
     @staticmethod
     def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
         """
@@ -396,7 +403,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return single_result["edge_exists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name_label = node_id.strip('"')
         query = """
           MATCH (n:`{label}`) RETURN n
@@ -454,17 +461,7 @@ class AGEStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given labels
-
-        Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
-
-        Returns:
-            list: List of all relationships/edges found
-        """
+    ) -> dict[str, str] | None:
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

@@ -488,7 +485,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return result

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
@@ -526,7 +523,7 @@ class AGEStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((AGEQueryException,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the AGE database.

@@ -562,8 +559,8 @@ class AGEStorage(BaseGraphStorage):
         retry=retry_if_exception_type((AGEQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -619,3 +616,23 @@ class AGEStorage(BaseGraphStorage):
             yield connection
         finally:
             await self._driver.putconn(connection)
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # AGE handles persistence automatically
+        pass
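The import reshuffle above is repeated in every backend file that follows: dependencies are auto-installed via pipmaster when missing, and the import itself moves inside try/except so a failure surfaces as an actionable ImportError. The generic shape of the idiom (with a placeholder package name, not a real dependency):

    import pipmaster as pm

    # Best-effort auto-install of the optional backend driver.
    if not pm.is_installed("somedriver"):
        pm.install("somedriver")

    try:
        import somedriver  # placeholder for the real driver package
    except ImportError as e:
        raise ImportError(
            "`somedriver` is not installed. Please install it via pip: `pip install somedriver`."
        ) from e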
@@ -1,19 +1,29 @@
 import asyncio
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, final
 import numpy as np
-from chromadb import HttpClient, PersistentClient
-from chromadb.config import Settings
-
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
+import pipmaster as pm
+
+if not pm.is_installed("chromadb"):
+    pm.install("chromadb")
+
+try:
+    from chromadb import HttpClient, PersistentClient
+    from chromadb.config import Settings
+except ImportError as e:
+    raise ImportError(
+        "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
+    ) from e


+@final
 @dataclass
 class ChromaVectorDBStorage(BaseVectorStorage):
     """ChromaDB vector storage implementation."""

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         try:
             config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -102,7 +112,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"ChromaDB initialization failed: {str(e)}")
             raise

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             logger.warning("Empty data provided to vector DB")
             return []
@@ -151,7 +161,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB upsert: {str(e)}")
             raise

-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func([query])

@@ -183,6 +193,12 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB query: {str(e)}")
             raise

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         # ChromaDB handles persistence automatically
         pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
@@ -1,11 +1,13 @@
 import os
 import time
 import asyncio
-import faiss
+from typing import Any, final

 import json
 import numpy as np
-from tqdm.asyncio import tqdm as tqdm_async

 from dataclasses import dataclass
+import pipmaster as pm

 from lightrag.utils import (
     logger,
@@ -15,7 +17,19 @@ from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("faiss"):
+    pm.install("faiss")
+
+try:
+    import faiss
+    from tqdm.asyncio import tqdm as tqdm_async
+except ImportError as e:
+    raise ImportError(
+        "`faiss` library is not installed. Please install it via pip: `pip install faiss`."
+    ) from e
+
+
+@final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
     """
@@ -23,8 +37,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
     Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
     """

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Grab config values if available
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -57,7 +69,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Attempt to load an existing index + metadata from disk
         self._load_faiss_index()

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.

@@ -147,7 +159,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
         return [m["__id__"] for m in list_data]

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
         """
@@ -210,11 +222,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
             f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
         )

-    async def delete_entity(self, entity_name: str):
-        """
-        Delete a single entity by computing its hashed ID
-        the same way your code does it with `compute_mdhash_id`.
-        """
+    async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])
@@ -234,12 +242,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._remove_faiss_ids(relations)
             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

-    async def index_done_callback(self):
-        """
-        Called after indexing is done (save Faiss index + metadata).
-        """
+    async def index_done_callback(self) -> None:
         self._save_faiss_index()
         logger.info("Faiss index saved successfully.")

     # --------------------------------------------------------------------------------
     # Internal helper methods
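Note the two shapes index_done_callback settles into across this PR: file-backed stores (Faiss above, the JSON stores, NetworkX) flush their state to disk in the callback, while server-backed stores (ChromaDB, Milvus, MongoDB, Neo4j, AGE, Gremlin) reduce it to a documented no-op. A simplified sketch of both, not taken verbatim from the diff:

    class FileBackedStore:
        def _save_to_disk(self) -> None:
            pass  # stand-in for writing an index or JSON file

        async def index_done_callback(self) -> None:
            self._save_to_disk()  # e.g. Faiss saves index + metadata here


    class ServerBackedStore:
        async def index_done_callback(self) -> None:
            # The database server persists writes itself; nothing to flush.
            pass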
@@ -3,11 +3,11 @@ import inspect
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, final

 import numpy as np

-from gremlin_python.driver import client, serializer
-from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-from gremlin_python.driver.protocol import GremlinServerError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -15,11 +15,22 @@ from tenacity import (
     wait_exponential,
 )

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import logger

 from ..base import BaseGraphStorage

+try:
+    from gremlin_python.driver import client, serializer
+    from gremlin_python.driver.aiohttp.transport import AiohttpTransport
+    from gremlin_python.driver.protocol import GremlinServerError
+except ImportError as e:
+    raise ImportError(
+        "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
+    ) from e
+
+
+@final
 @dataclass
 class GremlinStorage(BaseGraphStorage):
     @staticmethod
@@ -76,8 +87,9 @@ class GremlinStorage(BaseGraphStorage):
         if self._driver:
             self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Gremlin handles persistence automatically
+        pass

     @staticmethod
     def _to_value_map(value: Any) -> str:
@@ -190,7 +202,7 @@ class GremlinStorage(BaseGraphStorage):

         return result[0]["has_edge"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name = GremlinStorage._fix_name(node_id)
         query = f"""g
           .V().has('graph', {self.graph_name})
@@ -252,17 +264,7 @@ class GremlinStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given names
-
-        Args:
-            source_node_id (str): Name of the source nodes
-            target_node_id (str): Name of the target nodes
-
-        Returns:
-            dict|None: Dict of found edge properties, or None if not found
-        """
+    ) -> dict[str, str] | None:
         entity_name_source = GremlinStorage._fix_name(source_node_id)
         entity_name_target = GremlinStorage._fix_name(target_node_id)
         query = f"""g
@@ -286,11 +288,7 @@ class GremlinStorage(BaseGraphStorage):
         )
         return edge_properties

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
-        """
-        Retrieves all edges (relationships) for a particular node identified by its name.
-        :return: List of tuples containing edge sources and targets
-        """
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_name = GremlinStorage._fix_name(source_node_id)
         query = f"""g
           .E()
@@ -316,7 +314,7 @@ class GremlinStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((GremlinServerError,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Gremlin graph.

@@ -357,8 +355,8 @@ class GremlinStorage(BaseGraphStorage):
         retry=retry_if_exception_type((GremlinServerError,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their names.

@@ -397,3 +395,19 @@ class GremlinStorage(BaseGraphStorage):

     async def _node2vec_embed(self):
         print("Implemented but never called.")
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
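With every graph backend now sharing the same signatures (get_node(...) -> dict[str, str] | None, get_node_edges(...) -> list[tuple[str, str]] | None, and so on), calling code can be written once against BaseGraphStorage and handle the None cases uniformly. A hypothetical caller, not part of the PR:

    from lightrag.base import BaseGraphStorage


    async def describe_node(storage: BaseGraphStorage, node_id: str) -> str:
        node = await storage.get_node(node_id)
        if node is None:
            return f"{node_id}: not found"
        edges = await storage.get_node_edges(node_id) or []
        return f"{node_id}: {len(node)} properties, {len(edges)} edges"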
@@ -1,56 +1,6 @@
-"""
-JsonDocStatus Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
 from dataclasses import dataclass
 import os
-from typing import Any, Union
+from typing import Any, Union, final
@@ -64,6 +14,7 @@ from lightrag.utils import (
 )


+@final
 @dataclass
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(data) - set(self._data.keys())
+        return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
@@ -88,7 +39,7 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        counts = {status: 0 for status in DocStatus}
+        counts = {status.value: 0 for status in DocStatus}
         for doc in self._data.values():
             counts[doc["status"]] += 1
         return counts
@@ -96,23 +47,17 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
-        """all documents with a specific status"""
+        """Get all documents with a specific status"""
         return {
             k: DocProcessingStatus(**v)
             for k, v in self._data.items()
-            if v["status"] == status
+            if v["status"] == status.value
         }

-    async def index_done_callback(self):
-        """Save data to file after indexing"""
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def upsert(self, data: dict[str, Any]) -> None:
-        """Update or insert document status
-
-        Args:
-            data: Dictionary of document IDs and their status data
-        """
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         self._data.update(data)
         await self.index_done_callback()
@@ -120,7 +65,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
-        """Delete document status by IDs"""
         for doc_id in doc_ids:
             self._data.pop(doc_id, None)
         await self.index_done_callback()
+
+    async def drop(self) -> None:
+        raise NotImplementedError
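Two details in this file are easy to miss. First, DocStatus (in base.py above) moved from class DocStatus(str, Enum) to StrEnum, which requires Python 3.11+ and formats members as their plain values. Second, comparisons and dictionary keys switch to status.value, keeping the stored JSON plain strings. A small sketch of the formatting difference on Python 3.11+ (class names invented for illustration):

    from enum import Enum, StrEnum


    class OldStatus(str, Enum):
        PENDING = "pending"


    class NewStatus(StrEnum):
        PENDING = "pending"


    print(f"{OldStatus.PENDING}")  # OldStatus.PENDING  (the enum name leaks out)
    print(f"{NewStatus.PENDING}")  # pending            (formats as the raw value)

    # Keying counts by .value keeps the dict JSON-friendly:
    counts = {status.value: 0 for status in NewStatus}
    counts["pending"] += 1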
@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, final

 from lightrag.base import (
     BaseKVStorage,
@@ -13,6 +13,7 @@ from lightrag.utils import (
 )


+@final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -22,10 +23,10 @@ class JsonKVStorage(BaseKVStorage):
         self._lock = asyncio.Lock()
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,8 +39,8 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def filter_keys(self, data: set[str]) -> set[str]:
-        return set(data) - set(self._data.keys())
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
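filter_keys is now uniformly keys: set[str] -> set[str] and returns the keys the store has not seen, which is what makes ingestion idempotent: callers insert only what is reported missing. A usage sketch built on that contract (the ingest helper is invented for illustration):

    async def ingest(storage, docs: dict[str, dict]) -> None:
        # Ask the store which of these ids it has never stored.
        new_keys = await storage.filter_keys(set(docs))
        if new_keys:
            await storage.upsert({k: docs[k] for k in new_keys})
            await storage.index_done_callback()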
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
@@ -10,17 +11,21 @@ import configparser

 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")

-from pymilvus import MilvusClient
+try:
+    from pymilvus import MilvusClient
+except ImportError as e:
+    raise ImportError(
+        "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class MilvusVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     @staticmethod
     def create_collection_if_not_exist(
         client: MilvusClient, collection_name: str, **kwargs
@@ -71,7 +76,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             dimension=self.embedding_func.embedding_dim,
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -106,7 +111,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         results = self._client.upsert(collection_name=self.namespace, data=list_data)
         return results

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         results = self._client.search(
             collection_name=self.namespace,
@@ -123,3 +128,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
             for dp in results[0]
         ]
+
+    async def index_done_callback(self) -> None:
+        # Milvus handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
@@ -1,22 +1,11 @@
 import os
 from dataclasses import dataclass
 import numpy as np
-import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
 import asyncio

-if not pm.is_installed("pymongo"):
-    pm.install("pymongo")
-
-if not pm.is_installed("motor"):
-    pm.install("motor")
-
-from typing import Any, List, Tuple, Union
-from motor.motor_asyncio import AsyncIOMotorClient
-from pymongo import MongoClient
-from pymongo.operations import SearchIndexModel
-from pymongo.errors import PyMongoError
+from typing import Any, List, Union, final

 from ..base import (
     BaseGraphStorage,
@@ -29,12 +18,29 @@ from ..base import (
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+import pipmaster as pm
+
+if not pm.is_installed("pymongo"):
+    pm.install("pymongo")
+
+if not pm.is_installed("motor"):
+    pm.install("motor")
+
+try:
+    from motor.motor_asyncio import AsyncIOMotorClient
+    from pymongo import MongoClient
+    from pymongo.operations import SearchIndexModel
+    from pymongo.errors import PyMongoError
+except ImportError as e:
+    raise ImportError(
+        "`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -60,17 +66,17 @@ class MongoKVStorage(BaseKVStorage):
         # Ensure collection exists
         create_collection_if_not_exists(uri, database.name, self._collection_name)

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
         return await cursor.to_list()

-    async def filter_keys(self, data: set[str]) -> set[str]:
-        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
         existing_ids = {str(x["_id"]) async for x in cursor}
-        return data - existing_ids
+        return keys - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -107,11 +113,16 @@ class MongoKVStorage(BaseKVStorage):
         else:
             return None

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
     async def drop(self) -> None:
         """Drop the collection"""
         await self._data.drop()


+@final
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
@@ -191,7 +202,12 @@ class MongoDocStatusStorage(DocStatusStorage):
             for doc in result
         }

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+

+@final
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """
@@ -429,7 +445,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """
         Return the full node document (including "edges"), or None if missing.
         """
@@ -437,11 +453,7 @@ class MongoGraphStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Return the first edge dict from source_node_id to target_node_id if it exists.
-        Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
-        """
+    ) -> dict[str, str] | None:
         pipeline = [
             {"$match": {"_id": source_node_id}},
             {
@@ -467,9 +479,7 @@ class MongoGraphStorage(BaseGraphStorage):
                 return e
         return None

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[List[Tuple[str, str]], None]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Return a list of (source_id, target_id) for direct edges from source_node_id.
         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
@@ -503,7 +513,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def upsert_node(self, node_id: str, node_data: dict):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Insert or update a node document. If new, create an empty edges array.
         """
@@ -513,8 +523,8 @@ class MongoGraphStorage(BaseGraphStorage):
         await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)

     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
         If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -540,7 +550,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def delete_node(self, node_id: str):
+    async def delete_node(self, node_id: str) -> None:
         """
         1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
@@ -557,7 +567,9 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         """
         Placeholder for demonstration, raises NotImplementedError.
         """
@@ -759,11 +771,14 @@ class MongoGraphStorage(BaseGraphStorage):

         return result

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+

+@final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -828,7 +843,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         except PyMongoError as _:
             logger.debug("vector index already exist")

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
         if not data:
             logger.warning("You are inserting an empty data set to vector DB")
@@ -871,7 +886,7 @@ class MongoVectorDBStorage(BaseVectorStorage):

         return list_data

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
         # Generate the embedding
         embedding = await self.embedding_func([query])
@@ -905,6 +920,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
             for doc in results
         ]

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+

 def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
     """Check if the collection exists. if not, create it."""
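A second standardization running through every backend is annotation style: Union[dict, None] becomes dict[str, str] | None (PEP 604 unions, Python 3.10+) and List/Tuple/Dict become builtin generics (PEP 585, Python 3.9+). Side by side, using signatures from the diff:

    import numpy as np
    from typing import Any

    # Old style (removed by this PR):
    #   async def get_node(self, node_id: str) -> Union[dict, None]: ...
    #   async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: ...

    # New style (added):
    async def get_node(node_id: str) -> dict[str, str] | None: ...


    async def embed_nodes(algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: ...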
@@ -1,80 +1,35 @@
-"""
-NanoVectorDB Storage Module
-=======================
-
-[same misplaced NetworkX module docstring and MIT license text as removed from the JsonDocStatus module above]
-"""
-
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
-import pipmaster as pm
-
-if not pm.is_installed("nano-vectordb"):
-    pm.install("nano-vectordb")
-
-from nano_vectordb import NanoVectorDB
 import time

 from lightrag.utils import (
     logger,
     compute_mdhash_id,
 )
+import pipmaster as pm
 from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("nano-vectordb"):
+    pm.install("nano-vectordb")
+
+try:
+    from nano_vectordb import NanoVectorDB
+except ImportError as e:
+    raise ImportError(
+        "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
+    ) from e


+@final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
@@ -95,7 +50,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             self.embedding_func.embedding_dim, storage_file=self._client_file_name
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -139,7 +94,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
         )

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         results = self._client.query(
@@ -176,7 +131,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

-    async def delete_entity(self, entity_name: str):
+    async def delete_entity(self, entity_name: str) -> None:
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
@@ -211,7 +166,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")

-    async def index_done_callback(self):
-        # Protect file write operation
+    async def index_done_callback(self) -> None:
         async with self._save_lock:
             self._client.save()
@@ -3,20 +3,11 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, Union, Tuple, List, Dict
-import pipmaster as pm
+from typing import Any, List, Dict, final
 import numpy as np
 import configparser

-if not pm.is_installed("neo4j"):
-    pm.install("neo4j")
-
-from neo4j import (
-    AsyncGraphDatabase,
-    exceptions as neo4jExceptions,
-    AsyncDriver,
-    AsyncManagedTransaction,
-    GraphDatabase,
-)
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -27,12 +18,29 @@ from tenacity import (
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+import pipmaster as pm
+
+if not pm.is_installed("neo4j"):
+    pm.install("neo4j")
+
+try:
+    from neo4j import (
+        AsyncGraphDatabase,
+        exceptions as neo4jExceptions,
+        AsyncDriver,
+        AsyncManagedTransaction,
+        GraphDatabase,
+    )
+except ImportError as e:
+    raise ImportError(
+        "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
     @staticmethod
@@ -140,8 +148,9 @@ class Neo4JStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Neo4j handles persistence automatically
+        pass

     async def _label_exists(self, label: str) -> bool:
         """Check if a label exists in the Neo4j database."""
@@ -191,7 +200,7 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return single_result["edgeExists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier.

         Args:
@@ -252,17 +261,7 @@ class Neo4JStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """Find edge between two nodes identified by their labels.
-
-        Args:
-            source_node_id (str): Label of the source node
-            target_node_id (str): Label of the target node
-
-        Returns:
-            dict: Edge properties if found, with at least {"weight": 0.0}
-            None: If error occurs
-        """
+    ) -> dict[str, str] | None:
         try:
             entity_name_label_source = source_node_id.strip('"')
             entity_name_label_target = target_node_id.strip('"')
@@ -321,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
             # Return default edge properties on error
             return {"weight": 0.0, "source_id": None, "target_id": None}

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_label = source_node_id.strip('"')

         """
@@ -364,7 +363,7 @@ class Neo4JStorage(BaseGraphStorage):
             )
         ),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Neo4j database.

@@ -405,8 +404,8 @@ class Neo4JStorage(BaseGraphStorage):
         ),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -603,7 +602,7 @@ class Neo4JStorage(BaseGraphStorage):
         await traverse(label, 0)
         return result

-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         """
         Get all existing node labels in the database
         Returns:
@@ -627,3 +626,11 @@ class Neo4JStorage(BaseGraphStorage):
         async for record in result:
             labels.append(record["label"])
         return labels
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
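The Neo4j, AGE, and Gremlin write paths keep their tenacity retry decorators through this refactor. The shape of that decorator, with an invented exception type standing in for the driver-specific ones:

    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )


    class TransientQueryError(Exception):
        pass


    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((TransientQueryError,)),
    )
    async def run_write_query() -> None:
        ...  # a real implementation issues the graph query here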
@@ -1,61 +1,12 @@
|
||||
"""
|
||||
NetworkX Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Version: 1.0.0

Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic

Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph

Usage:
from lightrag.storage.networkx_storage import NetworkXStorage

"""

import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
from typing import Any, cast, final

import numpy as np


from lightrag.types import KnowledgeGraph
from lightrag.utils import (
logger,
)
@@ -64,7 +15,15 @@ from lightrag.base import (
BaseGraphStorage,
)

try:
import networkx as nx
except ImportError as e:
raise ImportError(
"`networkx` library is not installed. Please install it via pip: `pip install networkx`."
) from e


@final
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
@@ -142,7 +101,7 @@ class NetworkXStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}

async def index_done_callback(self):
async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

async def has_node(self, node_id: str) -> bool:
@@ -151,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)

async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id)

async def node_degree(self, node_id: str) -> int:
@@ -162,35 +121,32 @@ class NetworkXStorage(BaseGraphStorage):

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id))

async def get_node_edges(self, source_node_id: str):
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None

async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data)

async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data)

async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.

:param node_id: The node_id to delete
"""
async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +182,11 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)

async def get_all_labels(self) -> list[str]:
raise NotImplementedError

async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
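The `ABC`/`@abstractmethod` enforcement this commit introduces means an incomplete backend now fails at construction time instead of at first call. A minimal sketch of the mechanism, with illustrative class names rather than the library's own:

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class GraphStorageBase(ABC):
    namespace: str

    @abstractmethod
    async def has_node(self, node_id: str) -> bool:
        """Concrete backends must implement this."""


@dataclass
class IncompleteStorage(GraphStorageBase):
    pass  # has_node is missing


@dataclass
class CompleteStorage(GraphStorageBase):
    async def has_node(self, node_id: str) -> bool:
        return False


try:
    IncompleteStorage(namespace="test")
except TypeError as e:
    # Can't instantiate abstract class IncompleteStorage ...
    print(e)

CompleteStorage(namespace="test")  # constructs fine

Under the old `raise NotImplementedError` convention, the same omission only surfaced when the missing coroutine was actually awaited.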
@@ -4,16 +4,11 @@ import asyncio
# import html
# import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final

import numpy as np
import pipmaster as pm

if not pm.is_installed("oracledb"):
pm.install("oracledb")


import oracledb
from lightrag.types import KnowledgeGraph

from ..base import (
BaseGraphStorage,
@@ -23,6 +18,19 @@ from ..base import (
from ..namespace import NameSpace, is_namespace
from ..utils import logger

import pipmaster as pm

if not pm.is_installed("oracledb"):
pm.install("oracledb")

try:
import oracledb

except ImportError as e:
raise ImportError(
"`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
) from e


class OracleDB:
def __init__(self, config, **kwargs):
@@ -169,6 +177,7 @@ class OracleDB:
raise


@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -181,7 +190,7 @@ class OracleKVStorage(BaseKVStorage):

################ QUERY METHODS ################

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data based on id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +241,7 @@ class OracleKVStorage(BaseKVStorage):
res = [{k: v} for k, v in dict_res.items()]
return res

async def filter_keys(self, keys: list[str]) -> set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
return set(keys)

################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None:
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [
{
@@ -307,20 +316,17 @@ class OracleKVStorage(BaseKVStorage):

await self.db.execute(upsert_sql, _data)

async def index_done_callback(self):
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into oracle db!")
async def index_done_callback(self) -> None:
# Oracle handles persistence automatically
pass

async def drop(self) -> None:
raise NotImplementedError


@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None

def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -330,16 +336,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold

async def upsert(self, data: dict[str, dict]):
"""Insert data into the vector database."""
pass

async def index_done_callback(self):
pass

#################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
"""Query data from the vector database."""
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
# Convert the precision
@@ -359,21 +357,29 @@ class OracleVectorDBStorage(BaseVectorStorage):
# print("vector search result:",results)
return results

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
raise NotImplementedError

async def index_done_callback(self) -> None:
# Oracle handles persistence automatically
pass

async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError

async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError


@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: OracleDB

def __post_init__(self):
"""Load the graph from a graphml file."""
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

#################### insert method ################

async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""Insert or update a node."""
# print("go into upsert node method")
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
@@ -406,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):

async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
"""Insert or update an edge."""
# print("go into upsert edge method")
source_name = source_node_id
@@ -446,8 +452,9 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""Generate embeddings for nodes."""
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -464,11 +471,9 @@ class OracleGraphStorage(BaseGraphStorage):
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids

async def index_done_callback(self):
"""Write the graph to a graphml file."""
logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!"
)
async def index_done_callback(self) -> None:
# Oracle handles persistence automatically
pass

#################### query method #################
async def has_node(self, node_id: str) -> bool:
@@ -486,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False

async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""Check whether an edge exists, given source and target node ids."""
SQL = SQL_TEMPLATES["has_edge"]
params = {
"workspace": self.db.workspace,
@@ -503,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False

async def node_degree(self, node_id: str) -> int:
"""Get a node's degree by its id."""
SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
@@ -521,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge degree",degree)
return degree

async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node data by node id."""
SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
@@ -537,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""Get an edge by source and target node ids."""
) -> dict[str, str] | None:
SQL = SQL_TEMPLATES["get_edge"]
params = {
"workspace": self.db.workspace,
@@ -553,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
return None

async def get_node_edges(self, source_node_id: str):
"""Get all edges of a node by its id."""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"]
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
@@ -590,6 +591,17 @@ class OracleGraphStorage(BaseGraphStorage):
if res:
return res

async def delete_node(self, node_id: str) -> None:
raise NotImplementedError

async def get_all_labels(self) -> list[str]:
raise NotImplementedError

async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError


N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
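The import pattern this file converges on, a pipmaster install check followed by a guarded import, generalizes to any optional backend driver. A sketch of the recipe; `somedriver` is a placeholder package name, not a real dependency:

import pipmaster as pm

# Install the optional driver at runtime if it is missing.
if not pm.is_installed("somedriver"):
    pm.install("somedriver")

# Even after an install attempt, fail with an actionable message
# instead of a bare ModuleNotFoundError from deep inside the package.
try:
    import somedriver  # noqa: F401
except ImportError as e:
    raise ImportError(
        "`somedriver` library is not installed. "
        "Please install it via pip: `pip install somedriver`."
    ) from e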
@@ -4,24 +4,19 @@ import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Dict, List, Union, final

import numpy as np
import pipmaster as pm

if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
from lightrag.types import KnowledgeGraph

import sys

import asyncpg
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from tqdm.asyncio import tqdm as tqdm_async

from ..base import (
BaseGraphStorage,
@@ -39,6 +34,20 @@ if sys.platform.startswith("win"):

asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import pipmaster as pm

if not pm.is_installed("asyncpg"):
pm.install("asyncpg")

try:
import asyncpg
from tqdm.asyncio import tqdm as tqdm_async

except ImportError as e:
raise ImportError(
"`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
) from e


class PostgreSQLDB:
def __init__(self, config, **kwargs):
@@ -175,6 +184,7 @@ class PostgreSQLDB:
pass


@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -185,7 +195,7 @@ class PGKVStorage(BaseKVStorage):

################ QUERY METHODS ################

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +250,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)

async def filter_keys(self, keys: List[str]) -> Set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +271,7 @@ class PGKVStorage(BaseKVStorage):
print(params)

################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None:
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -287,20 +297,17 @@ class PGKVStorage(BaseKVStorage):

await self.db.execute(upsert_sql, _data)

async def index_done_callback(self):
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into postgresql db!")
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass

async def drop(self) -> None:
raise NotImplementedError


@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None

def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -352,7 +359,7 @@ class PGVectorStorage(BaseVectorStorage):
}
return upsert_sql, data

async def upsert(self, data: Dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -398,12 +405,8 @@ class PGVectorStorage(BaseVectorStorage):

await self.db.execute(upsert_sql, data)

async def index_done_callback(self):
logger.info("vector data had been saved into postgresql db!")

#################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
"""Query data from the vector database."""
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding))
@@ -417,23 +420,31 @@ class PGVectorStorage(BaseVectorStorage):
results = await self.db.query(sql, params=params, multirows=True)
return results

async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass

async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError

async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError


@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use
# db: PostgreSQLDB

async def filter_keys(self, data: set[str]) -> set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data])
keys = ",".join([f"'{_id}'" for _id in keys])
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None:
return set(data)
return set(keys)
else:
existed = set([element["id"] for element in result])
return set(data) - existed
return set(keys) - existed

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -452,6 +463,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=result[0]["updated_at"],
)

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError

async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
@@ -470,7 +484,7 @@ class PGDocStatusStorage(DocStatusStorage):
) -> Dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status}
params = {"workspace": self.db.workspace, "status": status.value}
result = await self.db.query(sql, params, True)
return {
element["id"]: DocProcessingStatus(
@@ -485,11 +499,11 @@ class PGDocStatusStorage(DocStatusStorage):
for element in result
}

async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass

async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Update or insert document status

Args:
@@ -520,31 +534,8 @@ class PGDocStatusStorage(DocStatusStorage):
)
return data

async def update_doc_status(self, data: dict[str, dict]) -> None:
"""
Updates only the document status, chunk count, and updated timestamp.

This method ensures that only relevant fields are updated instead of overwriting
the entire document record. If `updated_at` is not provided, the database will
automatically use the current timestamp.
"""
sql = """
UPDATE LIGHTRAG_DOC_STATUS
SET status = $3,
chunks_count = $4,
updated_at = CURRENT_TIMESTAMP
WHERE workspace = $1 AND id = $2
"""
for k, v in data.items():
_data = {
"workspace": self.db.workspace,
"id": k,
"status": v["status"].value, # Convert Enum to string
"chunks_count": v.get(
"chunks_count", -1
), # Default to -1 if not provided
}
await self.db.execute(sql, _data)
async def drop(self) -> None:
raise NotImplementedError


class PGGraphQueryException(Exception):
@@ -565,11 +556,9 @@ class PGGraphQueryException(Exception):
return self.details


@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: PostgreSQLDB

@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -580,8 +569,9 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}

async def index_done_callback(self):
print("KG successfully indexed.")
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass

@staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
@@ -811,7 +801,7 @@ class PGGraphStorage(BaseGraphStorage):
)
return single_result["edge_exists"]

async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -866,17 +856,7 @@ class PGGraphStorage(BaseGraphStorage):

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given labels

Args:
source_node_id (str): Label of the source nodes
target_node_id (str): Label of the target nodes

Returns:
list: List of all relationships/edges found
"""
) -> dict[str, str] | None:
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))

@@ -900,7 +880,7 @@ class PGGraphStorage(BaseGraphStorage):
)
return result

async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
@@ -948,14 +928,7 @@ class PGGraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the AGE database.

Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data

@@ -986,8 +959,8 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.

@@ -1029,6 +1002,22 @@ class PGGraphStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")

async def delete_node(self, node_id: str) -> None:
raise NotImplementedError

async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError

async def get_all_labels(self) -> list[str]:
raise NotImplementedError

async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError


NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
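Many hunks in this file only modernize annotation syntax: `Union[X, None]` becomes `X | None` (PEP 604), and `List`/`Set`/`Tuple`/`Dict` give way to the builtin generics. On Python 3.10+ the two union spellings are interchangeable at runtime:

from typing import Any, Union

# Old spelling, as removed by the diff:
def get_by_id_old(id: str) -> Union[dict[str, Any], None]:
    ...

# New spelling, as added by the diff:
def get_by_id_new(id: str) -> dict[str, Any] | None:
    ...

# The two forms compare equal on CPython 3.10+:
assert (dict[str, Any] | None) == Union[dict[str, Any], None]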
@@ -1,5 +1,6 @@
import asyncio
import os
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
@@ -7,16 +8,24 @@ import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
import configparser


config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

import pipmaster as pm

if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client")

from qdrant_client import QdrantClient, models
try:
from qdrant_client import QdrantClient, models

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
except ImportError:
raise ImportError(
"`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
)


def compute_mdhash_id_for_qdrant(
@@ -47,10 +56,9 @@ def compute_mdhash_id_for_qdrant(
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")


@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None

@staticmethod
def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs
@@ -85,7 +93,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
),
)

async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
@@ -130,7 +138,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
return results

async def query(self, query, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,
@@ -143,3 +151,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"query result: {results}")

return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]

async def index_done_callback(self) -> None:
# Qdrant handles persistence automatically
pass

async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError

async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
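The `@final` decorator stamped onto every concrete backend is a typing-level marker: checkers such as mypy or pyright report any further subclassing, while the Python runtime itself does not prevent it. A small illustration:

from typing import final


@final
class ConcreteStorage:
    """Marked final: not meant to be subclassed."""


# Runs fine at runtime, but a type checker reports:
# error: Cannot inherit from final class "ConcreteStorage"
class Derived(ConcreteStorage):
    pass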
@@ -1,5 +1,5 @@
import os
from typing import Any, Union
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
@@ -19,6 +19,7 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


@final
@dataclass
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -28,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async def get_by_id(self, id: str) -> dict[str, Any] | None:
data = await self._redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None

@@ -39,16 +40,16 @@ class RedisKVStorage(BaseKVStorage):
results = await pipe.execute()
return [json.loads(result) if result else None for result in results]

async def filter_keys(self, data: set[str]) -> set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
pipe = self._redis.pipeline()
for key in data:
for key in keys:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()

existing_ids = {data[i] for i, exists in enumerate(results) if exists}
return set(data) - existing_ids
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids

async def upsert(self, data: dict[str, Any]) -> None:
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pipe = self._redis.pipeline()
for k, v in tqdm_async(data.items(), desc="Upserting"):
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +62,7 @@ class RedisKVStorage(BaseKVStorage):
keys = await self._redis.keys(f"{self.namespace}:*")
if keys:
await self._redis.delete(*keys)

async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
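The `filter_keys` rename from `data` to `keys` also pins down one contract across all backends: accept a `set[str]` and return the subset not yet present in storage. A minimal in-memory sketch of that contract (a toy class, not the Redis implementation above):

import asyncio


class InMemoryKV:
    """Toy KV store, used only to illustrate the filter_keys contract."""

    def __init__(self) -> None:
        self._store: dict[str, dict] = {}

    async def filter_keys(self, keys: set[str]) -> set[str]:
        # Return the keys that do not exist in storage yet.
        return {k for k in keys if k not in self._store}


kv = InMemoryKV()
kv._store["doc-1"] = {"content": "..."}
print(asyncio.run(kv.filter_keys({"doc-1", "doc-2"})))  # {'doc-2'}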
@@ -1,9 +1,18 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final

import numpy as np

from lightrag.types import KnowledgeGraph

from tqdm import tqdm

from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger

import pipmaster as pm

if not pm.is_installed("pymysql"):
@@ -11,12 +20,13 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")

from sqlalchemy import create_engine, text
from tqdm import tqdm
try:
from sqlalchemy import create_engine, text

from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
except ImportError as e:
raise ImportError(
"`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
) from e


class TiDB:
@@ -99,6 +109,7 @@ class TiDB:
raise


@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -110,7 +121,7 @@ class TiDBKVStorage(BaseKVStorage):

################ QUERY METHODS ################

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Fetch doc_full data by id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id}
@@ -125,8 +136,7 @@ class TiDBKVStorage(BaseKVStorage):
)
return await self.db.query(SQL, multirows=True)

async def filter_keys(self, keys: list[str]) -> set[str]:
"""Filter out duplicated content."""
async def filter_keys(self, keys: set[str]) -> set[str]:
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace),
@@ -147,7 +157,7 @@ class TiDBKVStorage(BaseKVStorage):
return data

################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, Any]) -> None:
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -200,20 +210,17 @@ class TiDBKVStorage(BaseKVStorage):
await self.db.execute(merge_sql, data)
return left_data

async def index_done_callback(self):
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!")
async def index_done_callback(self) -> None:
# TiDB handles persistence automatically
pass

async def drop(self) -> None:
raise NotImplementedError


@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None

def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -227,7 +234,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold

async def query(self, query: str, top_k: int) -> list[dict]:
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
@@ -249,7 +256,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return results

###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -332,7 +339,18 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)

async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError

async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError

async def index_done_callback(self) -> None:
# TiDB handles persistence automatically
pass


@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
@@ -342,7 +360,7 @@ class TiDBGraphStorage(BaseGraphStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]

#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
@@ -373,7 +391,7 @@ class TiDBGraphStorage(BaseGraphStorage):

async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
@@ -409,7 +427,9 @@ class TiDBGraphStorage(BaseGraphStorage):
}
await self.db.execute(merge_sql, data)

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -442,14 +462,14 @@ class TiDBGraphStorage(BaseGraphStorage):
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree

async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
@@ -458,9 +478,7 @@ class TiDBGraphStorage(BaseGraphStorage):
}
return await self.db.query(sql, param)

async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
@@ -470,6 +488,21 @@ class TiDBGraphStorage(BaseGraphStorage):
else:
return []

async def index_done_callback(self) -> None:
# TiDB handles persistence automatically
pass

async def delete_node(self, node_id: str) -> None:
raise NotImplementedError

async def get_all_labels(self) -> list[str]:
raise NotImplementedError

async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError


N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
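The `embed_nodes` implementations repeated across these backends share one shape: look the algorithm name up in a dispatch table of coroutines and fail fast on unknown names. A self-contained sketch of that pattern (the zero vectors stand in for a real node2vec run):

import asyncio
from typing import Awaitable, Callable

import numpy as np


class NodeEmbedder:
    def __init__(self) -> None:
        # Map algorithm names to coroutine factories.
        self._node_embed_algorithms: dict[
            str, Callable[[], Awaitable[tuple[np.ndarray, list[str]]]]
        ] = {"node2vec": self._node2vec_embed}

    async def _node2vec_embed(self) -> tuple[np.ndarray, list[str]]:
        # Placeholder result; a real backend would compute embeddings here.
        return np.zeros((2, 4)), ["node_a", "node_b"]

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()


vectors, ids = asyncio.run(NodeEmbedder().embed_nodes("node2vec"))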
@@ -674,7 +674,7 @@ class LightRAG:
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"status": DocStatus.PENDING.value,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
@@ -745,7 +745,7 @@ class LightRAG:
await self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.PROCESSING,
"status": DocStatus.PROCESSING.value,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
@@ -779,10 +779,10 @@ class LightRAG:
]
try:
await asyncio.gather(*tasks)
await self.doc_status.update_doc_status(
await self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.PROCESSED,
"status": DocStatus.PROCESSED.value,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
@@ -796,10 +796,10 @@ class LightRAG:

except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.update_doc_status(
await self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.FAILED,
"status": DocStatus.FAILED.value,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
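The `DocStatus.X` -> `DocStatus.X.value` substitutions above hand the storage layer a plain string instead of an enum member; a plain `Enum` member is rejected by the stock JSON encoder, and `.value` sidesteps that. A sketch with an illustrative stand-in enum (`base.py` imports `StrEnum`, so the real DocStatus is likely str-backed, in which case `.value` mainly makes the stored type explicit):

import json
from enum import Enum


class Status(Enum):  # illustrative stand-in for DocStatus
    PENDING = "pending"


try:
    json.dumps({"status": Status.PENDING})
except TypeError as e:
    print(e)  # Object of type Status is not JSON serializable

print(json.dumps({"status": Status.PENDING.value}))  # {"status": "pending"}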