Merge pull request #795 from YanSte/make-clear-what-implemented-or-not
Enhancing ABC Enforcement and Standardizing Subclass Implementations
lightrag/base.py
@@ -1,9 +1,10 @@
 from __future__ import annotations

+from abc import ABC, abstractmethod
+from enum import StrEnum
 import os
 from dotenv import load_dotenv
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import (
     Any,
     Literal,
@@ -82,138 +83,130 @@ class QueryParam:


 @dataclass
-class StorageNameSpace:
+class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]

+    @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
-        pass


 @dataclass
-class BaseVectorStorage(StorageNameSpace):
+class BaseVectorStorage(StorageNameSpace, ABC):
     embedding_func: EmbeddingFunc
+    cosine_better_than_threshold: float = field(default=0.2)
     meta_fields: set[str] = field(default_factory=set)

+    @abstractmethod
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Query the vector storage and retrieve top_k results."""

+    @abstractmethod
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        """Use 'content' field from value for embedding, use key as id.
-
-        If embedding_func is None, use 'embedding' field from value
-        """
-        raise NotImplementedError
+        """Insert or update vectors in the storage."""

+    @abstractmethod
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete a single entity by its name"""
-        raise NotImplementedError
+        """Delete a single entity by its name."""

+    @abstractmethod
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete relations for a given entity by scanning metadata"""
-        raise NotImplementedError
+        """Delete relations for a given entity."""


 @dataclass
-class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseKVStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc

+    @abstractmethod
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        raise NotImplementedError
+        """Get value by id"""

+    @abstractmethod
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Get values by ids"""

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    @abstractmethod
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return un-exist keys"""
-        raise NotImplementedError

-    async def upsert(self, data: dict[str, Any]) -> None:
-        raise NotImplementedError
+    @abstractmethod
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """Upsert data"""

+    @abstractmethod
     async def drop(self) -> None:
-        raise NotImplementedError
+        """Drop the storage"""


 @dataclass
-class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseGraphStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc

+    @abstractmethod
     async def has_node(self, node_id: str) -> bool:
-        raise NotImplementedError
+        """Check if a node exists in the graph."""

+    @abstractmethod
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        raise NotImplementedError
+        """Check if an edge exists in the graph."""

+    @abstractmethod
     async def node_degree(self, node_id: str) -> int:
-        raise NotImplementedError
+        """Get the degree of a node."""

+    @abstractmethod
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        raise NotImplementedError
+        """Get the degree of an edge."""

+    @abstractmethod
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        raise NotImplementedError
+        """Get a node by its id."""

+    @abstractmethod
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        raise NotImplementedError
+        """Get an edge by its source and target node ids."""

+    @abstractmethod
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        raise NotImplementedError
+        """Get all edges connected to a node."""

+    @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        raise NotImplementedError
+        """Upsert a node into the graph."""

+    @abstractmethod
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        raise NotImplementedError
+        """Upsert an edge into the graph."""

+    @abstractmethod
     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node from the graph."""

+    @abstractmethod
     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
-        raise NotImplementedError("Node embedding is not used in lightrag.")
+        """Embed nodes using an algorithm."""

+    @abstractmethod
     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError
+        """Get all labels in the graph."""

+    @abstractmethod
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """Retrieve a subgraph of the knowledge graph starting from a given node."""


-class DocStatus(str, Enum):
-    """Document processing status enum"""
+class DocStatus(StrEnum):
+    """Document processing status"""

     PENDING = "pending"
     PROCESSING = "processing"
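The net effect of this hunk: every storage interface is now a real ABC, abstract methods keep only a docstring body, and completeness is enforced when a backend is instantiated rather than when a missing method is first called. A minimal sketch of that enforcement (class names hypothetical, not from the diff):

```python
from abc import ABC, abstractmethod


class DemoBase(ABC):
    @abstractmethod
    async def get_by_id(self, id: str) -> dict | None:
        """Get value by id"""
        # A docstring-only body is enough; the @abstractmethod
        # decorator is what keeps the method abstract.


class IncompleteBackend(DemoBase):  # forgets to implement get_by_id
    pass


IncompleteBackend()
# TypeError: Can't instantiate abstract class IncompleteBackend
# with abstract method get_by_id
```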
@@ -245,19 +238,16 @@ class DocProcessingStatus:
     """Additional metadata"""


-class DocStatusStorage(BaseKVStorage):
+@dataclass
+class DocStatusStorage(BaseKVStorage, ABC):
     """Base class for document status storage"""

+    @abstractmethod
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        raise NotImplementedError

+    @abstractmethod
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
-        raise NotImplementedError
-
-    async def update_doc_status(self, data: dict[str, Any]) -> None:
-        """Updates the status of a document. By default, it calls upsert."""
-        await self.upsert(data)
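`DocStatus` moving from `class DocStatus(str, Enum)` to `StrEnum` (available since Python 3.11) keeps members usable as plain strings while making `str()` yield the bare value instead of `DocStatus.PENDING`. A quick illustration of the semantics, assuming Python 3.11+:

```python
from enum import StrEnum


class DocStatus(StrEnum):
    PENDING = "pending"
    PROCESSING = "processing"


assert DocStatus.PENDING == "pending"       # members compare equal to their values
assert isinstance(DocStatus.PENDING, str)   # and are genuine str instances
assert f"{DocStatus.PENDING}" == "pending"  # formatting yields the bare value
```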
lightrag/kg/age_impl.py

@@ -5,19 +5,11 @@ import os
 import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union, final
+import numpy as np
 import pipmaster as pm
+from lightrag.types import KnowledgeGraph

-if not pm.is_installed("psycopg-pool"):
-    pm.install("psycopg-pool")
-    pm.install("psycopg[binary,pool]")
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
-
-
-import psycopg
-from psycopg.rows import namedtuple_row
-from psycopg_pool import AsyncConnectionPool, PoolTimeout
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,6 +27,23 @@ if sys.platform.startswith("win"):
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


+if not pm.is_installed("psycopg-pool"):
+    pm.install("psycopg-pool")
+    pm.install("psycopg[binary,pool]")
+
+if not pm.is_installed("asyncpg"):
+    pm.install("asyncpg")
+
+try:
+    import psycopg
+    from psycopg.rows import namedtuple_row
+    from psycopg_pool import AsyncConnectionPool, PoolTimeout
+except ImportError:
+    raise ImportError(
+        "`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
+    )
+
+
 class AGEQueryException(Exception):
     """Exception for the AGE queries."""
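This relocated block is the dependency guard the PR standardizes across backends: check with pipmaster, install if absent, then import inside try/except so a residual failure surfaces as an actionable message. A generic sketch of the pattern (package name illustrative):

```python
import pipmaster as pm

# Install the dependency on first use if it is missing.
if not pm.is_installed("somepackage"):  # illustrative name
    pm.install("somepackage")

try:
    import somepackage
except ImportError as e:
    # Re-raise with install instructions instead of a bare ImportError.
    raise ImportError(
        "`somepackage` is not installed. Please install it via pip: "
        "`pip install somepackage`."
    ) from e
```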
@@ -53,6 +62,7 @@ class AGEQueryException(Exception):
         return self.details


+@final
 @dataclass
 class AGEStorage(BaseGraphStorage):
     @staticmethod
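`@final` from `typing` marks these concrete backends as leaves of the hierarchy: type checkers flag any further subclassing, which complements the runtime ABC enforcement in `base.py`. It is a static-only marker with no runtime effect, roughly:

```python
from typing import final


@final
class AGELikeStorage:  # stand-in for a @final-decorated backend
    ...


# A type checker (mypy, pyright) reports an error here; at runtime the
# subclass still works, since @final only annotates the class.
class MyTweakedStorage(AGELikeStorage):
    ...
```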
@@ -98,9 +108,6 @@ class AGEStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
-
     @staticmethod
     def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
         """
@@ -396,7 +403,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return single_result["edge_exists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name_label = node_id.strip('"')
         query = """
               MATCH (n:`{label}`) RETURN n
@@ -454,17 +461,7 @@ class AGEStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given labels
-
-        Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
-
-        Returns:
-            list: List of all relationships/edges found
-        """
+    ) -> dict[str, str] | None:
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

@@ -488,7 +485,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return result

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
@@ -526,7 +523,7 @@ class AGEStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((AGEQueryException,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the AGE database.

@@ -562,8 +559,8 @@ class AGEStorage(BaseGraphStorage):
         retry=retry_if_exception_type((AGEQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -619,3 +616,23 @@ class AGEStorage(BaseGraphStorage):
             yield connection
         finally:
             await self._driver.putconn(connection)
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # AGES handles persistence automatically
+        pass
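Because the base class is now abstract, AGEStorage has to spell out even the operations it does not support; the explicit `raise NotImplementedError` stubs keep the class instantiable while still failing loudly if an unsupported method is actually called. The division of labor, sketched with hypothetical classes:

```python
from abc import ABC, abstractmethod


class GraphBase(ABC):
    @abstractmethod
    async def delete_node(self, node_id: str) -> None:
        """Delete a node from the graph."""


class ReadOnlyBackend(GraphBase):  # hypothetical backend without deletes
    async def delete_node(self, node_id: str) -> None:
        # Explicit stub: satisfies the ABC so the class can be
        # instantiated, but signals "unsupported" at call time.
        raise NotImplementedError


backend = ReadOnlyBackend()  # fine; awaiting delete_node() would raise
```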
lightrag/kg/chroma_impl.py

@@ -1,19 +1,29 @@
 import asyncio
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, final
 import numpy as np
-from chromadb import HttpClient, PersistentClient
-from chromadb.config import Settings
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
+import pipmaster as pm
+
+if not pm.is_installed("chromadb"):
+    pm.install("chromadb")
+
+try:
+    from chromadb import HttpClient, PersistentClient
+    from chromadb.config import Settings
+except ImportError as e:
+    raise ImportError(
+        "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
+    ) from e


+@final
 @dataclass
 class ChromaVectorDBStorage(BaseVectorStorage):
     """ChromaDB vector storage implementation."""

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         try:
             config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -102,7 +112,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"ChromaDB initialization failed: {str(e)}")
             raise

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             logger.warning("Empty data provided to vector DB")
             return []
@@ -151,7 +161,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB upsert: {str(e)}")
             raise

-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func([query])

@@ -183,6 +193,12 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB query: {str(e)}")
             raise

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         # ChromaDB handles persistence automatically
         pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
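Note the `raise ... from e` here, which the psycopg guard above omits: chaining stores the original `ImportError` as `__cause__`, so the underlying failure still appears in the traceback. For example (module name hypothetical):

```python
try:
    import not_a_real_module  # hypothetical missing dependency
except ImportError as e:
    raise ImportError("`not_a_real_module` is not installed.") from e
# The traceback shows the original ImportError first, then:
# "The above exception was the direct cause of the following exception: ..."
```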
lightrag/kg/faiss_impl.py

@@ -1,11 +1,13 @@
 import os
 import time
 import asyncio
-import faiss
+from typing import Any, final

 import json
 import numpy as np
-from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
+import pipmaster as pm
+
 from lightrag.utils import (
     logger,
@@ -15,7 +17,19 @@ from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("faiss"):
+    pm.install("faiss")
+
+try:
+    import faiss
+    from tqdm.asyncio import tqdm as tqdm_async
+except ImportError as e:
+    raise ImportError(
+        "`faiss` library is not installed. Please install it via pip: `pip install faiss`."
+    ) from e
+
+
+@final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
     """
@@ -23,8 +37,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
     Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
     """

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Grab config values if available
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -57,7 +69,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Attempt to load an existing index + metadata from disk
         self._load_faiss_index()

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.

@@ -147,7 +159,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
         return [m["__id__"] for m in list_data]

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
         """
@@ -210,11 +222,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
             f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
         )

-    async def delete_entity(self, entity_name: str):
-        """
-        Delete a single entity by computing its hashed ID
-        the same way your code does it with `compute_mdhash_id`.
-        """
+    async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])
@@ -234,12 +242,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._remove_faiss_ids(relations)
         logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

-    async def index_done_callback(self):
-        """
-        Called after indexing is done (save Faiss index + metadata).
-        """
+    async def index_done_callback(self) -> None:
         self._save_faiss_index()
-        logger.info("Faiss index saved successfully.")

     # --------------------------------------------------------------------------------
     # Internal helper methods
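`index_done_callback` is the commit hook declared on `StorageNameSpace`: file-backed stores flush state in it (Faiss saves its index above, the JSON stores rewrite their files), while server-backed stores make it an explicit no-op. The two shapes, sketched with hypothetical classes:

```python
class FileBackedStore:
    def _save_to_disk(self) -> None:
        ...  # e.g. persist an index or JSON file

    async def index_done_callback(self) -> None:
        # File-backed: flush in-memory state once indexing finishes.
        self._save_to_disk()


class ServerBackedStore:
    async def index_done_callback(self) -> None:
        # Server-backed (Mongo, Milvus, Chroma, ...): the database persists
        # writes itself, so the commit hook is deliberately a no-op.
        pass
```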
lightrag/kg/gremlin_impl.py

@@ -3,11 +3,11 @@ import inspect
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, final

+import numpy as np

-from gremlin_python.driver import client, serializer
-from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-from gremlin_python.driver.protocol import GremlinServerError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -15,11 +15,22 @@ from tenacity import (
     wait_exponential,
 )

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import logger

 from ..base import BaseGraphStorage

+try:
+    from gremlin_python.driver import client, serializer
+    from gremlin_python.driver.aiohttp.transport import AiohttpTransport
+    from gremlin_python.driver.protocol import GremlinServerError
+except ImportError as e:
+    raise ImportError(
+        "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
+    ) from e


+@final
 @dataclass
 class GremlinStorage(BaseGraphStorage):
     @staticmethod
@@ -76,8 +87,9 @@ class GremlinStorage(BaseGraphStorage):
         if self._driver:
             self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Gremlin handles persistence automatically
+        pass

     @staticmethod
     def _to_value_map(value: Any) -> str:
@@ -190,7 +202,7 @@ class GremlinStorage(BaseGraphStorage):

         return result[0]["has_edge"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name = GremlinStorage._fix_name(node_id)
         query = f"""g
             .V().has('graph', {self.graph_name})
@@ -252,17 +264,7 @@ class GremlinStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given names
-
-        Args:
-            source_node_id (str): Name of the source nodes
-            target_node_id (str): Name of the target nodes
-
-        Returns:
-            dict|None: Dict of found edge properties, or None if not found
-        """
+    ) -> dict[str, str] | None:
         entity_name_source = GremlinStorage._fix_name(source_node_id)
         entity_name_target = GremlinStorage._fix_name(target_node_id)
         query = f"""g
@@ -286,11 +288,7 @@ class GremlinStorage(BaseGraphStorage):
         )
         return edge_properties

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
-        """
-        Retrieves all edges (relationships) for a particular node identified by its name.
-        :return: List of tuples containing edge sources and targets
-        """
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_name = GremlinStorage._fix_name(source_node_id)
         query = f"""g
             .E()
@@ -316,7 +314,7 @@ class GremlinStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((GremlinServerError,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Gremlin graph.

@@ -357,8 +355,8 @@ class GremlinStorage(BaseGraphStorage):
         retry=retry_if_exception_type((GremlinServerError,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their names.

@@ -397,3 +395,19 @@ class GremlinStorage(BaseGraphStorage):

     async def _node2vec_embed(self):
         print("Implemented but never called.")
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
lightrag/kg/json_doc_status_impl.py

@@ -1,56 +1,6 @@
-"""
-JsonDocStatus Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
-
 from dataclasses import dataclass
 import os
-from typing import Any, Union
+from typing import Any, Union, final

 from lightrag.base import (
     DocProcessingStatus,
@@ -64,6 +14,7 @@ from lightrag.utils import (
 )


+@final
 @dataclass
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(data) - set(self._data.keys())
+        return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
@@ -88,7 +39,7 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        counts = {status: 0 for status in DocStatus}
+        counts = {status.value: 0 for status in DocStatus}
         for doc in self._data.values():
             counts[doc["status"]] += 1
         return counts
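With `DocStatus` now a `StrEnum`, the added `.value` makes the dictionary keys plain strings, matching what is serialized to and read back from JSON; lookups keep working either way because StrEnum members hash and compare like their values. A short sketch:

```python
from enum import StrEnum


class DocStatus(StrEnum):
    PENDING = "pending"


counts = {status.value: 0 for status in DocStatus}  # keys are plain str
counts["pending"] += 1                  # matches data loaded back from JSON
assert counts[DocStatus.PENDING] == 1   # enum member hashes like its value
```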
@@ -96,23 +47,17 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
-        """all documents with a specific status"""
+        """Get all documents with a specific status"""
         return {
             k: DocProcessingStatus(**v)
             for k, v in self._data.items()
-            if v["status"] == status
+            if v["status"] == status.value
         }

-    async def index_done_callback(self):
-        """Save data to file after indexing"""
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def upsert(self, data: dict[str, Any]) -> None:
-        """Update or insert document status
-
-        Args:
-            data: Dictionary of document IDs and their status data
-        """
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         self._data.update(data)
         await self.index_done_callback()
@@ -120,7 +65,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
-        """Delete document status by IDs"""
         for doc_id in doc_ids:
             self._data.pop(doc_id, None)
         await self.index_done_callback()
+
+    async def drop(self) -> None:
+        raise NotImplementedError
lightrag/kg/json_kv_impl.py

@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, final

 from lightrag.base import (
     BaseKVStorage,
@@ -13,6 +13,7 @@ from lightrag.utils import (
 )


+@final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -22,10 +23,10 @@ class JsonKVStorage(BaseKVStorage):
         self._lock = asyncio.Lock()
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,8 +39,8 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def filter_keys(self, data: set[str]) -> set[str]:
-        return set(data) - set(self._data.keys())
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
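The rename from `data` to `keys` makes the contract read correctly at call sites: `filter_keys` receives candidate keys and returns the subset not yet present in storage, so callers only upsert what is new. A hypothetical usage sketch (`kv_storage` and `payloads` are illustrative names):

```python
async def ingest(kv_storage, payloads: dict[str, dict]) -> None:
    # Only upsert documents the store has not seen yet.
    candidate_keys = set(payloads)
    new_keys = await kv_storage.filter_keys(candidate_keys)  # the "un-exist" keys
    await kv_storage.upsert({k: payloads[k] for k in new_keys})
```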
lightrag/kg/milvus_impl.py

@@ -1,5 +1,6 @@
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
@@ -10,17 +11,21 @@ import configparser

 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")
-from pymilvus import MilvusClient

+try:
+    from pymilvus import MilvusClient
+except ImportError as e:
+    raise ImportError(
+        "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class MilvusVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     @staticmethod
     def create_collection_if_not_exist(
         client: MilvusClient, collection_name: str, **kwargs
@@ -71,7 +76,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             dimension=self.embedding_func.embedding_dim,
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -106,7 +111,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         results = self._client.upsert(collection_name=self.namespace, data=list_data)
         return results

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         results = self._client.search(
             collection_name=self.namespace,
@@ -123,3 +128,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
             for dp in results[0]
         ]
+
+    async def index_done_callback(self) -> None:
+        # Milvus handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
@@ -1,22 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
if not pm.is_installed("pymongo"):
|
from typing import Any, List, Union, final
|
||||||
pm.install("pymongo")
|
|
||||||
|
|
||||||
if not pm.is_installed("motor"):
|
|
||||||
pm.install("motor")
|
|
||||||
|
|
||||||
from typing import Any, List, Tuple, Union
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.operations import SearchIndexModel
|
|
||||||
from pymongo.errors import PyMongoError
|
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
@@ -29,12 +18,29 @@ from ..base import (
|
|||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
import pipmaster as pm
|
||||||
|
|
||||||
|
if not pm.is_installed("pymongo"):
|
||||||
|
pm.install("pymongo")
|
||||||
|
|
||||||
|
if not pm.is_installed("motor"):
|
||||||
|
pm.install("motor")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.operations import SearchIndexModel
|
||||||
|
from pymongo.errors import PyMongoError
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
|
||||||
|
) from e
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoKVStorage(BaseKVStorage):
|
class MongoKVStorage(BaseKVStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -60,17 +66,17 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
# Ensure collection exists
|
# Ensure collection exists
|
||||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
return await self._data.find_one({"_id": id})
|
return await self._data.find_one({"_id": id})
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
cursor = self._data.find({"_id": {"$in": ids}})
|
cursor = self._data.find({"_id": {"$in": ids}})
|
||||||
return await cursor.to_list()
|
return await cursor.to_list()
|
||||||
|
|
||||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
|
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
|
||||||
existing_ids = {str(x["_id"]) async for x in cursor}
|
existing_ids = {str(x["_id"]) async for x in cursor}
|
||||||
return data - existing_ids
|
return keys - existing_ids
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
@@ -107,11 +113,16 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
# Mongo handles persistence automatically
|
||||||
|
pass
|
||||||
|
|
||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
"""Drop the collection"""
|
"""Drop the collection"""
|
||||||
await self._data.drop()
|
await self._data.drop()
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoDocStatusStorage(DocStatusStorage):
|
class MongoDocStatusStorage(DocStatusStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -191,7 +202,12 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|||||||
for doc in result
|
for doc in result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
# Mongo handles persistence automatically
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoGraphStorage(BaseGraphStorage):
|
class MongoGraphStorage(BaseGraphStorage):
|
||||||
"""
|
"""
|
||||||
@@ -429,7 +445,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""
|
"""
|
||||||
Return the full node document (including "edges"), or None if missing.
|
Return the full node document (including "edges"), or None if missing.
|
||||||
"""
|
"""
|
||||||
@@ -437,11 +453,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""
|
|
||||||
Return the first edge dict from source_node_id to target_node_id if it exists.
|
|
||||||
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
|
|
||||||
"""
|
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{"$match": {"_id": source_node_id}},
|
{"$match": {"_id": source_node_id}},
|
||||||
{
|
{
|
||||||
@@ -467,9 +479,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
return e
|
return e
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
self, source_node_id: str
|
|
||||||
) -> Union[List[Tuple[str, str]], None]:
|
|
||||||
"""
|
"""
|
||||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
||||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||||
@@ -503,7 +513,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert or update a node document. If new, create an empty edges array.
|
Insert or update a node document. If new, create an empty edges array.
|
||||||
"""
|
"""
|
||||||
@@ -513,8 +523,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
||||||
         If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -540,7 +550,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def delete_node(self, node_id: str):
+    async def delete_node(self, node_id: str) -> None:
         """
         1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
@@ -557,7 +567,9 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         """
         Placeholder for demonstration, raises NotImplementedError.
         """
@@ -759,11 +771,14 @@ class MongoGraphStorage(BaseGraphStorage):

         return result

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+

+@final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -828,7 +843,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         except PyMongoError as _:
             logger.debug("vector index already exist")

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
         if not data:
             logger.warning("You are inserting an empty data set to vector DB")
@@ -871,7 +886,7 @@ class MongoVectorDBStorage(BaseVectorStorage):

         return list_data

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
         # Generate the embedding
         embedding = await self.embedding_func([query])
@@ -905,6 +920,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
             for doc in results
         ]

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+

 def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
     """Check if the collection exists. if not, create it."""
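The Mongo hunks above show the subclass contract this PR standardizes across every backend: concrete storages are marked `@final`, the per-class `cosine_better_than_threshold` field moves up to the base class, `index_done_callback` becomes an explicit no-op wherever the server commits on write, and unsupported operations raise `NotImplementedError` instead of being silently absent. A minimal sketch of a conforming vector storage, assuming only the abstract bases exported by `lightrag.base` (the in-memory backend itself is hypothetical, purely for illustration):

    from dataclasses import dataclass
    from typing import Any, final

    from lightrag.base import BaseVectorStorage


    @final
    @dataclass
    class InMemoryVectorStorage(BaseVectorStorage):
        def __post_init__(self):
            # Hypothetical backend: rows live in a plain dict.
            self._rows: dict[str, dict[str, Any]] = {}

        async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
            self._rows.update(data)

        async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
            # A real backend embeds `query` and ranks by cosine similarity
            # against cosine_better_than_threshold; this stub just truncates.
            return list(self._rows.values())[:top_k]

        async def index_done_callback(self) -> None:
            # Nothing to flush: writes are visible immediately.
            pass

        async def delete_entity(self, entity_name: str) -> None:
            raise NotImplementedError

        async def delete_entity_relation(self, entity_name: str) -> None:
            raise NotImplementedError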
@@ -1,80 +1,35 @@
-"""
-NanoVectorDB Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
-import pipmaster as pm
-
-if not pm.is_installed("nano-vectordb"):
-    pm.install("nano-vectordb")
-
-from nano_vectordb import NanoVectorDB
 import time

 from lightrag.utils import (
     logger,
     compute_mdhash_id,
 )
+import pipmaster as pm
 from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("nano-vectordb"):
+    pm.install("nano-vectordb")
+
+try:
+    from nano_vectordb import NanoVectorDB
+except ImportError as e:
+    raise ImportError(
+        "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
+    ) from e
+
+
+@final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
@@ -95,7 +50,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             self.embedding_func.embedding_dim, storage_file=self._client_file_name
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -139,7 +94,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
             )

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         results = self._client.query(
@@ -176,7 +131,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

-    async def delete_entity(self, entity_name: str):
+    async def delete_entity(self, entity_name: str) -> None:
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
@@ -211,7 +166,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")

-    async def index_done_callback(self):
-        # Protect file write operation
+    async def index_done_callback(self) -> None:
         async with self._save_lock:
             self._client.save()
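The import preamble above is the dependency-guard idiom repeated in every storage module this PR touches: a best-effort auto-install through pipmaster, then a guarded import that converts a residual ImportError into an actionable message. A sketch of the same three steps factored into one helper (the helper is hypothetical; the diff inlines it per module):

    import importlib
    from types import ModuleType

    import pipmaster as pm


    def require(module_name: str, pip_name: str) -> ModuleType:
        """Import an optional dependency, installing it first when missing."""
        if not pm.is_installed(pip_name):
            pm.install(pip_name)
        try:
            return importlib.import_module(module_name)
        except ImportError as e:
            # `from e` keeps the original traceback attached to the new error.
            raise ImportError(
                f"`{pip_name}` library is not installed. "
                f"Please install it via pip: `pip install {pip_name}`."
            ) from e


    # Equivalent to the inline block above:
    # nano_vectordb = require("nano_vectordb", "nano-vectordb")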
@@ -3,20 +3,11 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, Union, Tuple, List, Dict
-import pipmaster as pm
+from typing import Any, List, Dict, final
+import numpy as np
 import configparser

-if not pm.is_installed("neo4j"):
-    pm.install("neo4j")
-
-from neo4j import (
-    AsyncGraphDatabase,
-    exceptions as neo4jExceptions,
-    AsyncDriver,
-    AsyncManagedTransaction,
-    GraphDatabase,
-)
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -27,12 +18,29 @@ from tenacity import (
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+import pipmaster as pm
+
+if not pm.is_installed("neo4j"):
+    pm.install("neo4j")
+
+try:
+    from neo4j import (
+        AsyncGraphDatabase,
+        exceptions as neo4jExceptions,
+        AsyncDriver,
+        AsyncManagedTransaction,
+        GraphDatabase,
+    )
+except ImportError as e:
+    raise ImportError(
+        "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
+    ) from e
+
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
     @staticmethod
@@ -140,8 +148,9 @@ class Neo4JStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Neo4j handles persistence automatically
+        pass

     async def _label_exists(self, label: str) -> bool:
         """Check if a label exists in the Neo4j database."""
@@ -191,7 +200,7 @@ class Neo4JStorage(BaseGraphStorage):
             )
             return single_result["edgeExists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier.

         Args:
@@ -252,17 +261,7 @@ class Neo4JStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """Find edge between two nodes identified by their labels.
-
-        Args:
-            source_node_id (str): Label of the source node
-            target_node_id (str): Label of the target node
-
-        Returns:
-            dict: Edge properties if found, with at least {"weight": 0.0}
-            None: If error occurs
-        """
+    ) -> dict[str, str] | None:
         try:
             entity_name_label_source = source_node_id.strip('"')
             entity_name_label_target = target_node_id.strip('"')
@@ -321,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
             # Return default edge properties on error
             return {"weight": 0.0, "source_id": None, "target_id": None}

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_label = source_node_id.strip('"')

         """
@@ -364,7 +363,7 @@ class Neo4JStorage(BaseGraphStorage):
                 )
             ),
         )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Neo4j database.

@@ -405,8 +404,8 @@ class Neo4JStorage(BaseGraphStorage):
             ),
         )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -603,7 +602,7 @@ class Neo4JStorage(BaseGraphStorage):
         await traverse(label, 0)
         return result

-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         """
         Get all existing node labels in the database
         Returns:
@@ -627,3 +626,11 @@ class Neo4JStorage(BaseGraphStorage):
             async for record in result:
                 labels.append(record["label"])
             return labels
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
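Two enforcement mechanisms meet in the class headers above, and they behave differently at runtime. The `ABC`/`@abstractmethod` machinery fails fast: a subclass that leaves any abstract method unimplemented cannot even be instantiated. `@final`, by contrast, is purely a static marker that type checkers such as mypy enforce; nothing happens at runtime if someone subclasses a final class. A small self-contained demonstration (stand-in classes, not LightRAG code):

    from abc import ABC, abstractmethod
    from typing import final


    class Storage(ABC):
        @abstractmethod
        async def index_done_callback(self) -> None: ...


    @final
    class GoodStorage(Storage):
        async def index_done_callback(self) -> None:
            pass  # server-backed stores have nothing to flush


    class BadStorage(Storage):
        pass  # abstract method left unimplemented


    GoodStorage()  # fine
    try:
        BadStorage()  # raises at construction time, not at call time
    except TypeError as e:
        print(e)  # Can't instantiate abstract class BadStorage ...

    # Subclassing GoodStorage would pass at runtime but be rejected by a
    # type checker, which is the guarantee @final actually provides.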
@@ -1,61 +1,12 @@
-"""
-NetworkX Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
 import html
 import os
 from dataclasses import dataclass
-from typing import Any, Union, cast
-import networkx as nx
+from typing import Any, cast, final
 import numpy as np

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import (
     logger,
 )
@@ -64,7 +15,15 @@ from lightrag.base import (
     BaseGraphStorage,
 )

+try:
+    import networkx as nx
+except ImportError as e:
+    raise ImportError(
+        "`networkx` library is not installed. Please install it via pip: `pip install networkx`."
+    ) from e
+
+
+@final
 @dataclass
 class NetworkXStorage(BaseGraphStorage):
     @staticmethod
@@ -142,7 +101,7 @@ class NetworkXStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

     async def has_node(self, node_id: str) -> bool:
@@ -151,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage):
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         return self._graph.has_edge(source_node_id, target_node_id)

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         return self._graph.nodes.get(node_id)

     async def node_degree(self, node_id: str) -> int:
@@ -162,35 +121,32 @@ class NetworkXStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    ) -> dict[str, str] | None:
         return self._graph.edges.get((source_node_id, target_node_id))

-    async def get_node_edges(self, source_node_id: str):
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         if self._graph.has_node(source_node_id):
             return list(self._graph.edges(source_node_id))
         return None

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         self._graph.add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         self._graph.add_edge(source_node_id, target_node_id, **edge_data)

-    async def delete_node(self, node_id: str):
-        """
-        Delete a node from the graph based on the specified node_id.
-
-        :param node_id: The node_id to delete
-        """
+    async def delete_node(self, node_id: str) -> None:
         if self._graph.has_node(node_id):
             self._graph.remove_node(node_id)
             logger.info(f"Node {node_id} deleted from the graph.")
         else:
             logger.warning(f"Node {node_id} not found in the graph for deletion.")

-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         if algorithm not in self._node_embed_algorithms:
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +182,11 @@ class NetworkXStorage(BaseGraphStorage):
         for source, target in edges:
             if self._graph.has_edge(source, target):
                 self._graph.remove_edge(source, target)
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
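NetworkXStorage now declares `get_all_labels` and `get_knowledge_graph` but raises `NotImplementedError`, and `get_node_edges` is annotated as returning `list[tuple[str, str]] | None` rather than an implicit list. Both changes push decisions to the call site. A caller-side sketch (hypothetical helpers; LightRAG's real call sites may differ):

    async def labels_or_empty(storage) -> list[str]:
        try:
            return await storage.get_all_labels()
        except NotImplementedError:
            # The backend declares the capability but does not support it yet.
            return []


    async def degree_via_edges(storage, node_id: str) -> int:
        edges = await storage.get_node_edges(node_id)
        # The `| None` in the annotation makes the missing-node case explicit,
        # so it must be handled rather than assumed to be an empty list.
        return 0 if edges is None else len(edges)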
@@ -4,16 +4,11 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, Union, final

 import numpy as np
-import pipmaster as pm
-
-if not pm.is_installed("oracledb"):
-    pm.install("oracledb")
-
-import oracledb
+from lightrag.types import KnowledgeGraph

 from ..base import (
     BaseGraphStorage,
@@ -23,6 +18,19 @@ from ..base import (
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger

+import pipmaster as pm
+
+if not pm.is_installed("oracledb"):
+    pm.install("oracledb")
+
+try:
+    import oracledb
+
+except ImportError as e:
+    raise ImportError(
+        "`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
+    ) from e
+

 class OracleDB:
     def __init__(self, config, **kwargs):
@@ -169,6 +177,7 @@ class OracleDB:
             raise


+@final
 @dataclass
 class OracleKVStorage(BaseKVStorage):
     # db instance must be injected before use
@@ -181,7 +190,7 @@ class OracleKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +241,7 @@ class OracleKVStorage(BaseKVStorage):
             res = [{k: v} for k, v in dict_res.items()]
             return res

-    async def filter_keys(self, keys: list[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
             return set(keys)

     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
@@ -307,20 +316,17 @@ class OracleKVStorage(BaseKVStorage):

         await self.db.execute(upsert_sql, _data)

-    async def index_done_callback(self):
-        if is_namespace(
-            self.namespace,
-            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
-        ):
-            logger.info("full doc and chunk data had been saved into oracle db!")
+    async def index_done_callback(self) -> None:
+        # Oracle handles persistence automatically
+        pass
+
+    async def drop(self) -> None:
+        raise NotImplementedError


+@final
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
-    # db instance must be injected before use
-    # db: OracleDB
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = config.get("cosine_better_than_threshold")
@@ -330,16 +336,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

-    async def upsert(self, data: dict[str, dict]):
-        """Insert data into the vector database"""
-        pass
-
-    async def index_done_callback(self):
-        pass
-
     #################### query method ###############
-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
-        """Query data from the vector database"""
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
         # Convert precision
@@ -359,21 +357,29 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # print("vector search result:",results)
         return results

+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # Oracle handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+

+@final
 @dataclass
 class OracleGraphStorage(BaseGraphStorage):
-    # db instance must be injected before use
-    # db: OracleDB
-
     def __post_init__(self):
-        """Load the graph from the graphml file"""
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

     #################### insert method ################

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-        """Insert or update a node"""
-        # print("go into upsert node method")
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id
         entity_type = node_data["entity_type"]
         description = node_data["description"]
@@ -406,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         """Insert or update an edge"""
         # print("go into upsert edge method")
         source_name = source_node_id
@@ -446,8 +452,9 @@ class OracleGraphStorage(BaseGraphStorage):
         await self.db.execute(merge_sql, data)
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
-        """Generate vectors for the nodes"""
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         if algorithm not in self._node_embed_algorithms:
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()
@@ -464,11 +471,9 @@ class OracleGraphStorage(BaseGraphStorage):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

-    async def index_done_callback(self):
-        """Write the graphml graph file"""
-        logger.info(
-            "Node and edge data had been saved into oracle db already, so nothing to do here!"
-        )
+    async def index_done_callback(self) -> None:
+        # Oracle handles persistence automatically
+        pass

     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
@@ -486,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
             return False

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        """Check whether an edge exists, given the source and target node ids"""
         SQL = SQL_TEMPLATES["has_edge"]
         params = {
             "workspace": self.db.workspace,
@@ -503,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
             return False

     async def node_degree(self, node_id: str) -> int:
-        """Get the degree of a node by its id"""
         SQL = SQL_TEMPLATES["node_degree"]
         params = {"workspace": self.db.workspace, "node_id": node_id}
         # print(SQL)
@@ -521,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
         # print("Edge degree",degree)
         return degree

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node data by node id"""
         SQL = SQL_TEMPLATES["get_node"]
         params = {"workspace": self.db.workspace, "node_id": node_id}
@@ -537,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """Get an edge by the source and target node ids"""
+    ) -> dict[str, str] | None:
         SQL = SQL_TEMPLATES["get_edge"]
         params = {
             "workspace": self.db.workspace,
@@ -553,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
         # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
         return None

-    async def get_node_edges(self, source_node_id: str):
-        """Get all edges of a node by its id"""
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         if await self.has_node(source_node_id):
             SQL = SQL_TEMPLATES["get_node_edges"]
             params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
@@ -590,6 +591,17 @@ class OracleGraphStorage(BaseGraphStorage):
             if res:
                 return res

+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
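The Oracle classes illustrate the rule this PR converges on for `index_done_callback`: storages whose backend commits on every write reduce the hook to `pass`, while file-backed storages still need it as their flush point. Sketched side by side (illustrative classes, assuming an asyncio lock and a client with a `save()` method as in the NanoVectorDB diff):

    import asyncio


    class ServerBackedStorage:
        async def index_done_callback(self) -> None:
            # Oracle/Postgres/Mongo/Qdrant: each upsert already committed,
            # so the end-of-indexing hook has nothing left to do.
            pass


    class FileBackedStorage:
        def __init__(self, client):
            self._client = client               # e.g. a NanoVectorDB instance
            self._save_lock = asyncio.Lock()    # serializes file writes

        async def index_done_callback(self) -> None:
            # In-memory state must be flushed to disk exactly here, under a
            # lock so concurrent callbacks cannot interleave partial writes.
            async with self._save_lock:
                self._client.save()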
@@ -4,24 +4,19 @@ import json
 import os
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Set, Tuple, Union
+from typing import Any, Dict, List, Union, final

 import numpy as np
-import pipmaster as pm
-
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
+from lightrag.types import KnowledgeGraph

 import sys

-import asyncpg
 from tenacity import (
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
-from tqdm.asyncio import tqdm as tqdm_async

 from ..base import (
     BaseGraphStorage,
@@ -39,6 +34,20 @@ if sys.platform.startswith("win"):

     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+import pipmaster as pm
+
+if not pm.is_installed("asyncpg"):
+    pm.install("asyncpg")
+
+try:
+    import asyncpg
+    from tqdm.asyncio import tqdm as tqdm_async
+
+except ImportError as e:
+    raise ImportError(
+        "`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
+    ) from e
+

 class PostgreSQLDB:
     def __init__(self, config, **kwargs):
@@ -175,6 +184,7 @@ class PostgreSQLDB:
         pass


+@final
 @dataclass
 class PGKVStorage(BaseKVStorage):
     # db instance must be injected before use
@@ -185,7 +195,7 @@ class PGKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +250,7 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)

-    async def filter_keys(self, keys: List[str]) -> Set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +271,7 @@ class PGKVStorage(BaseKVStorage):
             print(params)

     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -287,20 +297,17 @@ class PGKVStorage(BaseKVStorage):

         await self.db.execute(upsert_sql, _data)

-    async def index_done_callback(self):
-        if is_namespace(
-            self.namespace,
-            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
-        ):
-            logger.info("full doc and chunk data had been saved into postgresql db!")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass
+
+    async def drop(self) -> None:
+        raise NotImplementedError


+@final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -352,7 +359,7 @@ class PGVectorStorage(BaseVectorStorage):
         }
         return upsert_sql, data

-    async def upsert(self, data: Dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -398,12 +405,8 @@ class PGVectorStorage(BaseVectorStorage):

         await self.db.execute(upsert_sql, data)

-    async def index_done_callback(self):
-        logger.info("vector data had been saved into postgresql db!")
-
     #################### query method ###############
-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
-        """Query data from the vector database"""
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
@@ -417,23 +420,31 @@ class PGVectorStorage(BaseVectorStorage):
         results = await self.db.query(sql, params=params, multirows=True)
         return results

+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+

+@final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
-        keys = ",".join([f"'{_id}'" for _id in data])
+        keys = ",".join([f"'{_id}'" for _id in keys])
         sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
         result = await self.db.query(sql, multirows=True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
-            return set(data)
+            return set(keys)
         else:
             existed = set([element["id"] for element in result])
-            return set(data) - existed
+            return set(keys) - existed

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -452,6 +463,9 @@ class PGDocStatusStorage(DocStatusStorage):
             updated_at=result[0]["updated_at"],
         )

+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        raise NotImplementedError
+
     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
         sql = """SELECT status as "status", COUNT(1) as "count"
@@ -470,7 +484,7 @@ class PGDocStatusStorage(DocStatusStorage):
     ) -> Dict[str, DocProcessingStatus]:
         """all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
-        params = {"workspace": self.db.workspace, "status": status}
+        params = {"workspace": self.db.workspace, "status": status.value}
         result = await self.db.query(sql, params, True)
         return {
             element["id"]: DocProcessingStatus(
@@ -485,11 +499,11 @@ class PGDocStatusStorage(DocStatusStorage):
             for element in result
         }

-    async def index_done_callback(self):
-        """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
-        logger.info("Doc status had been saved into postgresql db!")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """Update or insert document status

         Args:
@@ -520,31 +534,8 @@ class PGDocStatusStorage(DocStatusStorage):
         )
         return data

-    async def update_doc_status(self, data: dict[str, dict]) -> None:
-        """
-        Updates only the document status, chunk count, and updated timestamp.
-
-        This method ensures that only relevant fields are updated instead of overwriting
-        the entire document record. If `updated_at` is not provided, the database will
-        automatically use the current timestamp.
-        """
-        sql = """
-        UPDATE LIGHTRAG_DOC_STATUS
-        SET status = $3,
-            chunks_count = $4,
-            updated_at = CURRENT_TIMESTAMP
-        WHERE workspace = $1 AND id = $2
-        """
-        for k, v in data.items():
-            _data = {
-                "workspace": self.db.workspace,
-                "id": k,
-                "status": v["status"].value,  # Convert Enum to string
-                "chunks_count": v.get(
-                    "chunks_count", -1
-                ),  # Default to -1 if not provided
-            }
-            await self.db.execute(sql, _data)
+    async def drop(self) -> None:
+        raise NotImplementedError


 class PGGraphQueryException(Exception):
@@ -565,11 +556,9 @@ class PGGraphQueryException(Exception):
         return self.details


+@final
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-
     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")
@@ -580,8 +569,9 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass

     @staticmethod
     def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
@@ -811,7 +801,7 @@ class PGGraphStorage(BaseGraphStorage):
         )
         return single_result["edge_exists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
@@ -866,17 +856,7 @@ class PGGraphStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given labels
-
-        Args:
-            source_node_id (str): Label of the source nodes
-            target_node_id (str): Label of the target nodes
-
-        Returns:
-            list: List of all relationships/edges found
-        """
+    ) -> dict[str, str] | None:
         src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
         tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))

@@ -900,7 +880,7 @@ class PGGraphStorage(BaseGraphStorage):
         )
         return result

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
@@ -948,14 +928,7 @@ class PGGraphStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
-        """
-        Upsert a node in the AGE database.
-
-        Args:
-            node_id: The unique identifier for the node (used as label)
-            node_data: Dictionary of node properties
-        """
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
         properties = node_data

@@ -986,8 +959,8 @@ class PGGraphStorage(BaseGraphStorage):
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -1029,6 +1002,22 @@ class PGGraphStorage(BaseGraphStorage):
     async def _node2vec_embed(self):
         print("Implemented but never called.")

+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 NAMESPACE_TABLE_MAP = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
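The KV and doc-status classes above settle on one `filter_keys` contract: take a `set[str]`, return the subset that is not yet stored, so the indexer knows what still needs writing. Note that the Postgres body keeps rebinding the `keys` parameter to the joined SQL string, which makes the `result is None` branch return a set of single characters rather than of keys; a sketch of the contract with distinct names (hypothetical variant, not part of this diff):

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of `keys` that is NOT in storage yet."""
        id_list = ",".join(f"'{_id}'" for _id in keys)  # keep `keys` intact
        sql = (
            f"SELECT id FROM LIGHTRAG_DOC_STATUS "
            f"WHERE workspace='{self.db.workspace}' AND id IN ({id_list})"
        )
        result = await self.db.query(sql, multirows=True)
        if result is None:
            return keys  # nothing stored: every key is missing
        existed = {element["id"] for element in result}
        return keys - existed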
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, final
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -7,16 +8,24 @@ import hashlib
|
|||||||
import uuid
|
import uuid
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("qdrant_client"):
|
if not pm.is_installed("qdrant_client"):
|
||||||
pm.install("qdrant_client")
|
pm.install("qdrant_client")
|
||||||
|
|
||||||
from qdrant_client import QdrantClient, models
|
try:
|
||||||
|
from qdrant_client import QdrantClient, models
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
except ImportError:
|
||||||
config.read("config.ini", "utf-8")
|
raise ImportError(
|
||||||
|
"`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_mdhash_id_for_qdrant(
|
def compute_mdhash_id_for_qdrant(
|
||||||
@@ -47,10 +56,9 @@ def compute_mdhash_id_for_qdrant(
|
|||||||
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_collection_if_not_exist(
|
def create_collection_if_not_exist(
|
||||||
client: QdrantClient, collection_name: str, **kwargs
|
client: QdrantClient, collection_name: str, **kwargs
|
||||||
@@ -85,7 +93,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
             ),
         )
 
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
             return []
@@ -130,7 +138,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
         return results
 
-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         results = self._client.search(
             collection_name=self.namespace,
@@ -143,3 +151,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         logger.debug(f"query result: {results}")
 
         return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
+
+    async def index_done_callback(self) -> None:
+        # Qdrant handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
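
The retyped `query` rests on `QdrantClient.search`, whose `score_threshold` parameter maps naturally onto the `cosine_better_than_threshold` field the base class now carries with a default. A hedged sketch of that wiring, factored out of the class (the threshold plumbing is an assumption; this hunk only shows the call and result mapping):

from typing import Any
from qdrant_client import QdrantClient

def search_vectors(
    client: QdrantClient,
    namespace: str,
    embedding: list[float],
    top_k: int,
    score_threshold: float = 0.2,  # cf. cosine_better_than_threshold
) -> list[dict[str, Any]]:
    points = client.search(
        collection_name=namespace,
        query_vector=embedding,           # first row of embedding_func([query])
        limit=top_k,
        score_threshold=score_threshold,  # drop hits below the cosine cutoff
    )
    return [{**p.payload, "id": p.id, "distance": p.score} for p in points]
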
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Union
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -19,6 +19,7 @@ config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
 
+@final
 @dataclass
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -28,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
@@ -39,16 +40,16 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]
 
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
-        for key in data:
+        for key in keys:
             pipe.exists(f"{self.namespace}:{key}")
         results = await pipe.execute()
 
-        existing_ids = {data[i] for i, exists in enumerate(results) if exists}
-        return set(data) - existing_ids
+        existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
+        return set(keys) - existing_ids
 
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         pipe = self._redis.pipeline()
         for k, v in tqdm_async(data.items(), desc="Upserting"):
             pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +62,7 @@ class RedisKVStorage(BaseKVStorage):
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
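
One caveat survives the `filter_keys` rename: the parameter is now typed `set[str]`, yet `keys[i]` subscripts it, and Python sets do not support indexing, so the method would raise `TypeError: 'set' object is not subscriptable` on first call. Snapshotting the set into a list keeps iteration order aligned with the pipeline results; a corrected standalone sketch (assuming `redis.asyncio`, as the class uses):

from redis.asyncio import Redis

async def filter_keys(redis: Redis, namespace: str, keys: set[str]) -> set[str]:
    ordered = list(keys)  # sets are not subscriptable; a list fixes the order
    pipe = redis.pipeline()
    for key in ordered:
        pipe.exists(f"{namespace}:{key}")
    results = await pipe.execute()
    existing = {ordered[i] for i, exists in enumerate(results) if exists}
    return keys - existing  # keys is already a set; no set() call needed
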
@@ -1,9 +1,18 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, Union, final
 
 import numpy as np
+
+from lightrag.types import KnowledgeGraph
+
+from tqdm import tqdm
+
+from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
+from ..namespace import NameSpace, is_namespace
+from ..utils import logger
+
 import pipmaster as pm
 
 if not pm.is_installed("pymysql"):
@@ -11,12 +20,13 @@ if not pm.is_installed("pymysql"):
 if not pm.is_installed("sqlalchemy"):
     pm.install("sqlalchemy")
 
-from sqlalchemy import create_engine, text
-from tqdm import tqdm
-
-from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
-from ..namespace import NameSpace, is_namespace
-from ..utils import logger
+try:
+    from sqlalchemy import create_engine, text
+except ImportError as e:
+    raise ImportError(
+        "`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
+    ) from e
 
 
 class TiDB:
@@ -99,6 +109,7 @@ class TiDB:
             raise
 
 
+@final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
     # db instance must be injected before use
@@ -110,7 +121,7 @@ class TiDBKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
@@ -125,8 +136,7 @@ class TiDBKVStorage(BaseKVStorage):
         )
         return await self.db.query(SQL, multirows=True)
 
-    async def filter_keys(self, keys: list[str]) -> set[str]:
-        """过滤掉重复内容"""
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
             id_field=namespace_to_id(self.namespace),
@@ -147,7 +157,7 @@ class TiDBKVStorage(BaseKVStorage):
         return data
 
     ################ INSERT full_doc AND chunks ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -200,20 +210,17 @@ class TiDBKVStorage(BaseKVStorage):
         await self.db.execute(merge_sql, data)
         return left_data
 
-    async def index_done_callback(self):
-        if is_namespace(
-            self.namespace,
-            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
-        ):
-            logger.info("full doc and chunk data had been saved into TiDB db!")
+    async def index_done_callback(self) -> None:
+        # Ti handles persistence automatically
+        pass
+
+    async def drop(self) -> None:
+        raise NotImplementedError
 
 
+@final
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
-    # db instance must be injected before use
-    # db: TiDB
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
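
The rewritten `index_done_callback` is deliberately empty: TiDB, like Qdrant and Redis above, makes every write durable as it happens, so there is nothing left to flush when indexing finishes. The hook only carries weight for file-backed storages, where it is the commit point. A hedged sketch of the contrast (both classes are illustrative):

import json

class FileBackedKV:
    def __init__(self, path: str) -> None:
        self.path = path
        self._data: dict[str, dict] = {}

    async def upsert(self, data: dict[str, dict]) -> None:
        self._data.update(data)  # in memory only until the callback fires

    async def index_done_callback(self) -> None:
        # commit point: persist the in-memory view to disk
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self._data, f)

class ServerBackedKV:
    async def index_done_callback(self) -> None:
        pass  # the database already made each upsert durable
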
@@ -227,7 +234,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold
 
-    async def query(self, query: str, top_k: int) -> list[dict]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
@@ -249,7 +256,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         return results
 
     ###### INSERT entities And relationships ######
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         # ignore, upsert in TiDBKVStorage already
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -332,7 +339,18 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
 
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # Ti handles persistence automatically
+        pass
+
 
+@final
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
     # db instance must be injected before use
@@ -342,7 +360,7 @@ class TiDBGraphStorage(BaseGraphStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
     #################### upsert method ################
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id
         entity_type = node_data["entity_type"]
         description = node_data["description"]
@@ -373,7 +391,7 @@ class TiDBGraphStorage(BaseGraphStorage):
 
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         source_name = source_node_id
         target_name = target_node_id
         weight = edge_data["weight"]
@@ -409,7 +427,9 @@ class TiDBGraphStorage(BaseGraphStorage):
         }
         await self.db.execute(merge_sql, data)
 
-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         if algorithm not in self._node_embed_algorithms:
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()
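
`np.ndarray[Any, Any]` parameterizes the array type for static checkers only; it is erased at runtime. `numpy.typing.NDArray[Any]`, which expands to `np.ndarray[Any, np.dtype[Any]]`, is the conventional shorthand for the same intent:

from typing import Any

from numpy.typing import NDArray

async def embed_nodes(algorithm: str) -> tuple[NDArray[Any], list[str]]:
    ...  # illustrative signature only; mirrors the annotation in the hunk above
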
@@ -442,14 +462,14 @@ class TiDBGraphStorage(BaseGraphStorage):
         degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
         return degree
 
-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         sql = SQL_TEMPLATES["get_node"]
         param = {"name": node_id, "workspace": self.db.workspace}
         return await self.db.query(sql, param)
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    ) -> dict[str, str] | None:
         sql = SQL_TEMPLATES["get_edge"]
         param = {
             "source_name": source_node_id,
@@ -458,9 +478,7 @@ class TiDBGraphStorage(BaseGraphStorage):
         }
         return await self.db.query(sql, param)
 
-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         sql = SQL_TEMPLATES["get_node_edges"]
         param = {"source_name": source_node_id, "workspace": self.db.workspace}
         res = await self.db.query(sql, param, multirows=True)
@@ -470,6 +488,21 @@ class TiDBGraphStorage(BaseGraphStorage):
         else:
             return []
 
+    async def index_done_callback(self) -> None:
+        # Ti handles persistence automatically
+        pass
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
 
 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
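
`get_all_labels` and `get_knowledge_graph` are now declared on the TiDB backend but unsupported, so backend-agnostic callers can treat `NotImplementedError` as a capability signal rather than a crash. A hedged caller-side sketch (the helper is illustrative):

async def labels_or_empty(storage) -> list[str]:
    try:
        return await storage.get_all_labels()
    except NotImplementedError:
        return []  # backend declares the API but does not support it yet
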
@@ -674,7 +674,7 @@ class LightRAG:
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
-                "status": DocStatus.PENDING,
+                "status": DocStatus.PENDING.value,
                 "created_at": datetime.now().isoformat(),
                 "updated_at": datetime.now().isoformat(),
             }
@@ -745,7 +745,7 @@ class LightRAG:
             await self.doc_status.upsert(
                 {
                     doc_status_id: {
-                        "status": DocStatus.PROCESSING,
+                        "status": DocStatus.PROCESSING.value,
                        "updated_at": datetime.now().isoformat(),
                         "content": status_doc.content,
                         "content_summary": status_doc.content_summary,
@@ -779,10 +779,10 @@ class LightRAG:
             ]
             try:
                 await asyncio.gather(*tasks)
-                await self.doc_status.update_doc_status(
+                await self.doc_status.upsert(
                     {
                         doc_status_id: {
-                            "status": DocStatus.PROCESSED,
+                            "status": DocStatus.PROCESSED.value,
                             "chunks_count": len(chunks),
                             "content": status_doc.content,
                             "content_summary": status_doc.content_summary,
@@ -796,10 +796,10 @@ class LightRAG:
 
             except Exception as e:
                 logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                await self.doc_status.update_doc_status(
+                await self.doc_status.upsert(
                     {
                         doc_status_id: {
-                            "status": DocStatus.FAILED,
+                            "status": DocStatus.FAILED.value,
                             "error": str(e),
                             "content": status_doc.content,
                             "content_summary": status_doc.content_summary,
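
Writing `DocStatus.<STATE>.value` stores the plain string, which every KV backend can JSON-serialize; `json.dumps` rejects a non-str `Enum` member outright. A minimal reproduction (the `Status` enum is illustrative):

import json
from enum import Enum

class Status(Enum):
    PENDING = "pending"

json.dumps({"status": Status.PENDING.value})  # '{"status": "pending"}'
json.dumps({"status": Status.PENDING})        # TypeError: Object of type Status
                                              # is not JSON serializable

With a str-mixed enum (`class Status(str, Enum)` or `StrEnum`), members serialize directly, but `.value` keeps the stored payload unambiguous either way.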