Merge pull request #795 from YanSte/make-clear-what-implemented-or-not

Enhancing ABC Enforcement and Standardizing Subclass Implementations
zrguo
2025-02-17 16:03:22 +08:00
committed by GitHub
18 changed files with 622 additions and 611 deletions
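The pattern this PR applies everywhere, in miniature: the storage base classes now inherit from ABC and declare their operations with @abstractmethod, so a backend that forgets to implement an operation fails loudly at instantiation instead of silently inheriting a `raise NotImplementedError` body that only fires at call time. A minimal sketch of the mechanism (illustrative class names, not code from this PR):

    from abc import ABC, abstractmethod


    class Storage(ABC):
        @abstractmethod
        async def query(self, query: str, top_k: int) -> list[dict]:
            """Query the storage and return the top_k results."""


    class IncompleteStorage(Storage):
        pass  # query() is missing


    class CompleteStorage(Storage):
        async def query(self, query: str, top_k: int) -> list[dict]:
            return []


    CompleteStorage()    # fine
    IncompleteStorage()  # TypeError: Can't instantiate abstract class ...

Note that a docstring is a perfectly good body for an abstract method, which is why the diffs below can replace `raise NotImplementedError` with a one-line docstring without losing the enforcement.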

View File

@@ -1,9 +1,10 @@
 from __future__ import annotations
+from abc import ABC, abstractmethod
+from enum import StrEnum
 import os
 from dotenv import load_dotenv
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import (
     Any,
     Literal,
@@ -82,138 +83,130 @@ class QueryParam:
 @dataclass
-class StorageNameSpace:
+class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]

+    @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
-        pass

 @dataclass
-class BaseVectorStorage(StorageNameSpace):
+class BaseVectorStorage(StorageNameSpace, ABC):
     embedding_func: EmbeddingFunc
+    cosine_better_than_threshold: float = field(default=0.2)
     meta_fields: set[str] = field(default_factory=set)

+    @abstractmethod
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Query the vector storage and retrieve top_k results."""

+    @abstractmethod
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        """Use 'content' field from value for embedding, use key as id.
-        If embedding_func is None, use 'embedding' field from value
-        """
-        raise NotImplementedError
+        """Insert or update vectors in the storage."""

+    @abstractmethod
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete a single entity by its name"""
-        raise NotImplementedError
+        """Delete a single entity by its name."""

+    @abstractmethod
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete relations for a given entity by scanning metadata"""
-        raise NotImplementedError
+        """Delete relations for a given entity."""

 @dataclass
-class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseKVStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc

+    @abstractmethod
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        raise NotImplementedError
+        """Get value by id"""

+    @abstractmethod
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        raise NotImplementedError
+        """Get values by ids"""

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    @abstractmethod
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return un-exist keys"""
-        raise NotImplementedError

-    async def upsert(self, data: dict[str, Any]) -> None:
-        raise NotImplementedError
+    @abstractmethod
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """Upsert data"""

+    @abstractmethod
     async def drop(self) -> None:
-        raise NotImplementedError
+        """Drop the storage"""
 @dataclass
-class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc | None = None
+class BaseGraphStorage(StorageNameSpace, ABC):
+    embedding_func: EmbeddingFunc

+    """Check if a node exists in the graph."""
+
+    @abstractmethod
     async def has_node(self, node_id: str) -> bool:
-        raise NotImplementedError

     """Check if an edge exists in the graph."""

+    @abstractmethod
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        raise NotImplementedError

     """Get the degree of a node."""

+    @abstractmethod
     async def node_degree(self, node_id: str) -> int:
-        raise NotImplementedError

     """Get the degree of an edge."""

+    @abstractmethod
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        raise NotImplementedError

     """Get a node by its id."""

+    @abstractmethod
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        raise NotImplementedError

     """Get an edge by its source and target node ids."""

+    @abstractmethod
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        raise NotImplementedError

     """Get all edges connected to a node."""

+    @abstractmethod
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        raise NotImplementedError

     """Upsert a node into the graph."""

+    @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        raise NotImplementedError

     """Upsert an edge into the graph."""

+    @abstractmethod
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        raise NotImplementedError

     """Delete a node from the graph."""

+    @abstractmethod
     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError

     """Embed nodes using an algorithm."""

+    @abstractmethod
     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
-        raise NotImplementedError("Node embedding is not used in lightrag.")

     """Get all labels in the graph."""

+    @abstractmethod
     async def get_all_labels(self) -> list[str]:
-        raise NotImplementedError

     """Get a knowledge graph of a node."""

+    @abstractmethod
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
-        raise NotImplementedError
+        """Retrieve a subgraph of the knowledge graph starting from a given node."""
-class DocStatus(str, Enum):
-    """Document processing status enum"""
+class DocStatus(StrEnum):
+    """Document processing status"""

     PENDING = "pending"
     PROCESSING = "processing"
@@ -245,19 +238,16 @@ class DocProcessingStatus:
     """Additional metadata"""

-class DocStatusStorage(BaseKVStorage):
+@dataclass
+class DocStatusStorage(BaseKVStorage, ABC):
     """Base class for document status storage"""

+    @abstractmethod
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        raise NotImplementedError

+    @abstractmethod
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
-        raise NotImplementedError
-
-    async def update_doc_status(self, data: dict[str, Any]) -> None:
-        """Updates the status of a document. By default, it calls upsert."""
-        await self.upsert(data)
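Two behavioral notes on this first file, since the rest of the PR builds on it. First, DocStatus moving from the `str, Enum` mixin to `StrEnum` (new in Python 3.11) changes how members stringify: both still compare equal to their raw values, but `str()` of a mixin member yields the qualified name while `str()` of a StrEnum member yields the value. A quick sketch, assuming Python 3.11+:

    from enum import Enum, StrEnum


    class MixinStatus(str, Enum):
        PENDING = "pending"


    class PlainStatus(StrEnum):
        PENDING = "pending"


    assert MixinStatus.PENDING == "pending"  # equality holds either way
    assert PlainStatus.PENDING == "pending"

    print(str(MixinStatus.PENDING))  # MixinStatus.PENDING
    print(str(PlainStatus.PENDING))  # pending

Second, this is why the JSON document-status storage further down switches to explicit `status.value` comparisons and `status.value` dictionary keys: records loaded from disk hold raw strings, and keying by the member's value keeps lookups unambiguous regardless of how the enum stringifies.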

View File

@@ -5,19 +5,11 @@ import os
 import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union, final
+import numpy as np
 import pipmaster as pm
+from lightrag.types import KnowledgeGraph

-if not pm.is_installed("psycopg-pool"):
-    pm.install("psycopg-pool")
-    pm.install("psycopg[binary,pool]")
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
-
-import psycopg
-from psycopg.rows import namedtuple_row
-from psycopg_pool import AsyncConnectionPool, PoolTimeout
-
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,6 +27,23 @@ if sys.platform.startswith("win"):
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+if not pm.is_installed("psycopg-pool"):
+    pm.install("psycopg-pool")
+    pm.install("psycopg[binary,pool]")
+if not pm.is_installed("asyncpg"):
+    pm.install("asyncpg")
+
+try:
+    import psycopg
+    from psycopg.rows import namedtuple_row
+    from psycopg_pool import AsyncConnectionPool, PoolTimeout
+except ImportError:
+    raise ImportError(
+        "`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
+    )
+
 class AGEQueryException(Exception):
     """Exception for the AGE queries."""
@@ -53,6 +62,7 @@ class AGEQueryException(Exception):
         return self.details

+@final
 @dataclass
 class AGEStorage(BaseGraphStorage):
     @staticmethod
@@ -98,9 +108,6 @@ class AGEStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
-
     @staticmethod
     def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
         """
@@ -396,7 +403,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return single_result["edge_exists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name_label = node_id.strip('"')
         query = """
             MATCH (n:`{label}`) RETURN n
@@ -454,17 +461,7 @@ class AGEStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given labels
-        Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
-        Returns:
-            list: List of all relationships/edges found
-        """
+    ) -> dict[str, str] | None:
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')
@@ -488,7 +485,7 @@ class AGEStorage(BaseGraphStorage):
         )
         return result

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
@@ -526,7 +523,7 @@ class AGEStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((AGEQueryException,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the AGE database.
@@ -562,8 +559,8 @@ class AGEStorage(BaseGraphStorage):
         retry=retry_if_exception_type((AGEQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
@@ -619,3 +616,23 @@ class AGEStorage(BaseGraphStorage):
             yield connection
         finally:
             await self._driver.putconn(connection)
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # AGE handles persistence automatically
+        pass
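The import reshuffle at the top of this file is the template repeated across every backend in the PR: the pipmaster check moves below the standard imports, and the driver import is wrapped so a failure surfaces as one actionable message instead of a bare ModuleNotFoundError. The shape of the pattern, sketched with a placeholder package name:

    import pipmaster as pm

    # Install the optional driver on demand instead of declaring it
    # as a hard dependency of the whole package.
    if not pm.is_installed("somedriver"):
        pm.install("somedriver")

    try:
        import somedriver  # noqa: F401
    except ImportError as e:
        # Chain with `from e` so the original traceback stays attached.
        raise ImportError(
            "`somedriver` is not installed. "
            "Please install it via pip: `pip install somedriver`."
        ) from e

One small inconsistency: the psycopg block above raises without `from`, while the other backends chain the original exception with `from e`.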

View File

@@ -1,19 +1,29 @@
 import asyncio
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, final
 import numpy as np
-from chromadb import HttpClient, PersistentClient
-from chromadb.config import Settings
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
+import pipmaster as pm

+if not pm.is_installed("chromadb"):
+    pm.install("chromadb")
+
+try:
+    from chromadb import HttpClient, PersistentClient
+    from chromadb.config import Settings
+except ImportError as e:
+    raise ImportError(
+        "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
+    ) from e

+@final
 @dataclass
 class ChromaVectorDBStorage(BaseVectorStorage):
     """ChromaDB vector storage implementation."""

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         try:
             config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -102,7 +112,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"ChromaDB initialization failed: {str(e)}")
             raise

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             logger.warning("Empty data provided to vector DB")
             return []
@@ -151,7 +161,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB upsert: {str(e)}")
             raise

-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func([query])
@@ -183,6 +193,12 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB query: {str(e)}")
             raise

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         # ChromaDB handles persistence automatically
         pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
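`@final` (from `typing`) now decorates every concrete backend. It tells static type checkers such as mypy that the class must not be subclassed; at runtime it is essentially a no-op (on Python 3.11+ it records a `__final__` attribute but prevents nothing). A sketch, with an illustrative class name:

    from typing import final


    @final
    class ConcreteStorage:
        pass


    class MyStorage(ConcreteStorage):  # mypy: Cannot inherit from final class
        pass


    # Runtime is unaffected; the contract lives in static analysis only.
    print(getattr(ConcreteStorage, "__final__", False))

So the division of labor is: ABC and @abstractmethod guard the bottom of the hierarchy (implement everything before instantiating), and @final seals the top (do not extend a finished backend).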

View File

@@ -1,11 +1,13 @@
 import os
 import time
 import asyncio
-import faiss
+from typing import Any, final
 import json
 import numpy as np
-from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
+import pipmaster as pm

 from lightrag.utils import (
     logger,
@@ -15,7 +17,19 @@ from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("faiss"):
+    pm.install("faiss")
+
+try:
+    import faiss
+    from tqdm.asyncio import tqdm as tqdm_async
+except ImportError as e:
+    raise ImportError(
+        "`faiss` library is not installed. Please install it via pip: `pip install faiss`."
+    ) from e

+@final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
     """
@@ -23,8 +37,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
     Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
     """

-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Grab config values if available
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -57,7 +69,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Attempt to load an existing index + metadata from disk
         self._load_faiss_index()

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -147,7 +159,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
         return [m["__id__"] for m in list_data]

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
         """
@@ -210,11 +222,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
             f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
         )

-    async def delete_entity(self, entity_name: str):
-        """
-        Delete a single entity by computing its hashed ID
-        the same way your code does it with `compute_mdhash_id`.
-        """
+    async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])
@@ -234,12 +242,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._remove_faiss_ids(relations)
         logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

-    async def index_done_callback(self):
-        """
-        Called after indexing is done (save Faiss index + metadata).
-        """
+    async def index_done_callback(self) -> None:
         self._save_faiss_index()
-        logger.info("Faiss index saved successfully.")

     # --------------------------------------------------------------------------------
     # Internal helper methods

View File

@@ -3,11 +3,11 @@ import inspect
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Union, Tuple, List, Dict
+from typing import Any, List, Dict, final
+import numpy as np

-from gremlin_python.driver import client, serializer
-from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-from gremlin_python.driver.protocol import GremlinServerError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -15,11 +15,22 @@ from tenacity import (
     wait_exponential,
 )

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import logger
 from ..base import BaseGraphStorage

+try:
+    from gremlin_python.driver import client, serializer
+    from gremlin_python.driver.aiohttp.transport import AiohttpTransport
+    from gremlin_python.driver.protocol import GremlinServerError
+except ImportError as e:
+    raise ImportError(
+        "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
+    ) from e

+@final
 @dataclass
 class GremlinStorage(BaseGraphStorage):
     @staticmethod
@@ -76,8 +87,9 @@ class GremlinStorage(BaseGraphStorage):
         if self._driver:
             self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Gremlin handles persistence automatically
+        pass

     @staticmethod
     def _to_value_map(value: Any) -> str:
@@ -190,7 +202,7 @@ class GremlinStorage(BaseGraphStorage):
         return result[0]["has_edge"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         entity_name = GremlinStorage._fix_name(node_id)
         query = f"""g
             .V().has('graph', {self.graph_name})
@@ -252,17 +264,7 @@ class GremlinStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given names
-        Args:
-            source_node_id (str): Name of the source nodes
-            target_node_id (str): Name of the target nodes
-        Returns:
-            dict|None: Dict of found edge properties, or None if not found
-        """
+    ) -> dict[str, str] | None:
         entity_name_source = GremlinStorage._fix_name(source_node_id)
         entity_name_target = GremlinStorage._fix_name(target_node_id)
         query = f"""g
@@ -286,11 +288,7 @@ class GremlinStorage(BaseGraphStorage):
         )
         return edge_properties

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
-        """
-        Retrieves all edges (relationships) for a particular node identified by its name.
-        :return: List of tuples containing edge sources and targets
-        """
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_name = GremlinStorage._fix_name(source_node_id)
         query = f"""g
             .E()
@@ -316,7 +314,7 @@ class GremlinStorage(BaseGraphStorage):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((GremlinServerError,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Gremlin graph.
@@ -357,8 +355,8 @@ class GremlinStorage(BaseGraphStorage):
         retry=retry_if_exception_type((GremlinServerError,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their names.
@@ -397,3 +395,19 @@ class GremlinStorage(BaseGraphStorage):
     async def _node2vec_embed(self):
         print("Implemented but never called.")
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError

View File

@@ -1,56 +1,6 @@
-"""
-JsonDocStatus Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-"""
-
 from dataclasses import dataclass
 import os
-from typing import Any, Union
+from typing import Any, Union, final

 from lightrag.base import (
     DocProcessingStatus,
@@ -64,6 +14,7 @@ from lightrag.utils import (
 )

+@final
 @dataclass
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(data) - set(self._data.keys())
+        return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
@@ -88,7 +39,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
-        counts = {status: 0 for status in DocStatus}
+        counts = {status.value: 0 for status in DocStatus}
         for doc in self._data.values():
             counts[doc["status"]] += 1
         return counts
@@ -96,23 +47,17 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
-        """all documents with a specific status"""
+        """Get all documents with a specific status"""
         return {
             k: DocProcessingStatus(**v)
             for k, v in self._data.items()
-            if v["status"] == status
+            if v["status"] == status.value
         }

-    async def index_done_callback(self):
-        """Save data to file after indexing"""
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def upsert(self, data: dict[str, Any]) -> None:
-        """Update or insert document status
-
-        Args:
-            data: Dictionary of document IDs and their status data
-        """
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         self._data.update(data)
         await self.index_done_callback()
@@ -120,7 +65,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
-        """Delete document status by IDs"""
         for doc_id in doc_ids:
             self._data.pop(doc_id, None)
         await self.index_done_callback()
+
+    async def drop(self) -> None:
+        raise NotImplementedError
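The `drop` stub above illustrates the second half of the PR's convention: abstract methods in the bases force an override, but a concrete @final backend that genuinely does not support an operation overrides it with an explicit `raise NotImplementedError`. The override satisfies the ABC, so the class still instantiates; the error is deferred to the unsupported call. A sketch with invented names:

    import asyncio
    from abc import ABC, abstractmethod


    class KVBase(ABC):
        @abstractmethod
        async def drop(self) -> None:
            """Drop the storage"""


    class FileBackedKV(KVBase):
        async def drop(self) -> None:
            # Present, so the ABC is satisfied and instantiation works,
            # but the operation itself is unsupported by this backend.
            raise NotImplementedError


    async def main() -> None:
        kv = FileBackedKV()  # no TypeError here
        try:
            await kv.drop()  # the failure surfaces only at call time
        except NotImplementedError:
            print("drop() is not supported by this backend")


    asyncio.run(main())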

View File

@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, final

 from lightrag.base import (
     BaseKVStorage,
@@ -13,6 +13,7 @@ from lightrag.utils import (
 )

+@final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -22,10 +23,10 @@ class JsonKVStorage(BaseKVStorage):
         self._lock = asyncio.Lock()
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,8 +39,8 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def filter_keys(self, data: set[str]) -> set[str]:
-        return set(data) - set(self._data.keys())
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
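A note on the recurring `Union[X, None]` → `X | None` rewrites in these signatures: the `|` union is PEP 604 syntax, evaluated at runtime only on Python 3.10+, but inside annotations it is harmless on older interpreters whenever the module defers annotation evaluation, as base.py does with `from __future__ import annotations` in its first hunk. A small sketch (illustrative function name):

    from __future__ import annotations  # PEP 563: annotations become lazy strings

    from typing import Any


    def get_by_id(id: str) -> dict[str, Any] | None:
        # The PEP 604 union above is never evaluated at runtime thanks to
        # the future import, so this also runs on Python 3.9.
        return None


    print(get_by_id("some-key"))  # None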

View File

@@ -1,5 +1,6 @@
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
@@ -10,17 +11,21 @@ import configparser

 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")
-from pymilvus import MilvusClient

+try:
+    from pymilvus import MilvusClient
+except ImportError as e:
+    raise ImportError(
+        "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

+@final
 @dataclass
 class MilvusVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     @staticmethod
     def create_collection_if_not_exist(
         client: MilvusClient, collection_name: str, **kwargs
@@ -71,7 +76,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             dimension=self.embedding_func.embedding_dim,
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -106,7 +111,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         results = self._client.upsert(collection_name=self.namespace, data=list_data)
         return results

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         results = self._client.search(
             collection_name=self.namespace,
@@ -123,3 +128,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
             for dp in results[0]
         ]
+
+    async def index_done_callback(self) -> None:
+        # Milvus handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError

View File

@@ -1,22 +1,11 @@
 import os
 from dataclasses import dataclass
 import numpy as np
-import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
 import asyncio
+from typing import Any, List, Union, final

-if not pm.is_installed("pymongo"):
-    pm.install("pymongo")
-
-if not pm.is_installed("motor"):
-    pm.install("motor")
-
-from typing import Any, List, Tuple, Union
-
-from motor.motor_asyncio import AsyncIOMotorClient
-from pymongo import MongoClient
-from pymongo.operations import SearchIndexModel
-from pymongo.errors import PyMongoError
-
 from ..base import (
     BaseGraphStorage,
@@ -29,12 +18,29 @@ from ..base import (
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+import pipmaster as pm

+if not pm.is_installed("pymongo"):
+    pm.install("pymongo")
+
+if not pm.is_installed("motor"):
+    pm.install("motor")
+
+try:
+    from motor.motor_asyncio import AsyncIOMotorClient
+    from pymongo import MongoClient
+    from pymongo.operations import SearchIndexModel
+    from pymongo.errors import PyMongoError
+except ImportError as e:
+    raise ImportError(
+        "`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

+@final
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -60,17 +66,17 @@ class MongoKVStorage(BaseKVStorage):
         # Ensure collection exists
         create_collection_if_not_exists(uri, database.name, self._collection_name)

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
         return await cursor.to_list()

-    async def filter_keys(self, data: set[str]) -> set[str]:
-        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
         existing_ids = {str(x["_id"]) async for x in cursor}
-        return data - existing_ids
+        return keys - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -107,11 +113,16 @@ class MongoKVStorage(BaseKVStorage):
         else:
             return None

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
     async def drop(self) -> None:
         """Drop the collection"""
         await self._data.drop()

+@final
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
@@ -191,7 +202,12 @@ class MongoDocStatusStorage(DocStatusStorage):
             for doc in result
         }

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
+@final
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """
@@ -429,7 +445,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """
         Return the full node document (including "edges"), or None if missing.
         """
@@ -437,11 +453,7 @@ class MongoGraphStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Return the first edge dict from source_node_id to target_node_id if it exists.
-        Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
-        """
+    ) -> dict[str, str] | None:
         pipeline = [
             {"$match": {"_id": source_node_id}},
             {
@@ -467,9 +479,7 @@ class MongoGraphStorage(BaseGraphStorage):
                 return e
         return None

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[List[Tuple[str, str]], None]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Return a list of (source_id, target_id) for direct edges from source_node_id.
         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
@@ -503,7 +513,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def upsert_node(self, node_id: str, node_data: dict):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Insert or update a node document. If new, create an empty edges array.
         """
@@ -513,8 +523,8 @@ class MongoGraphStorage(BaseGraphStorage):
         await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)

     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
         If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -540,7 +550,7 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def delete_node(self, node_id: str):
+    async def delete_node(self, node_id: str) -> None:
         """
         1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
@@ -557,7 +567,9 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     #

-    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         """
         Placeholder for demonstration, raises NotImplementedError.
         """
@@ -759,11 +771,14 @@ class MongoGraphStorage(BaseGraphStorage):
         return result

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
+@final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -828,7 +843,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         except PyMongoError as _:
             logger.debug("vector index already exist")

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
         if not data:
             logger.warning("You are inserting an empty data set to vector DB")
@@ -871,7 +886,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         return list_data

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
         # Generate the embedding
         embedding = await self.embedding_func([query])
@@ -905,6 +920,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
             for doc in results
         ]

+    async def index_done_callback(self) -> None:
+        # Mongo handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError

 def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
     """Check if the collection exists. if not, create it."""

View File

@@ -1,80 +1,35 @@
-"""
-NanoVectorDB Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-"""
-
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
-import pipmaster as pm
-
-if not pm.is_installed("nano-vectordb"):
-    pm.install("nano-vectordb")
-
-from nano_vectordb import NanoVectorDB
 import time

 from lightrag.utils import (
     logger,
     compute_mdhash_id,
 )
+import pipmaster as pm
 from lightrag.base import (
     BaseVectorStorage,
 )

+if not pm.is_installed("nano-vectordb"):
+    pm.install("nano-vectordb")
+
+try:
+    from nano_vectordb import NanoVectorDB
+except ImportError as e:
+    raise ImportError(
+        "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
+    ) from e

+@final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
@@ -95,7 +50,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             self.embedding_func.embedding_dim, storage_file=self._client_file_name
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -139,7 +94,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
             )

-    async def query(self, query: str, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         results = self._client.query(
@@ -176,7 +131,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

-    async def delete_entity(self, entity_name: str):
+    async def delete_entity(self, entity_name: str) -> None:
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
@@ -211,7 +166,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")

-    async def index_done_callback(self):
-        # Protect file write operation
+    async def index_done_callback(self) -> None:
         async with self._save_lock:
             self._client.save()
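Several server-backed stores in this PR reduce index_done_callback to a comment plus pass, since their databases persist on their own; the file-backed NanoVectorDB keeps a real callback that flushes the index to disk under an asyncio.Lock (the inline comment was dropped in this diff, but the lock remains). The shape of that pattern, sketched with a stand-in save routine:

    import asyncio


    class FileBackedIndex:
        def __init__(self) -> None:
            self._save_lock = asyncio.Lock()

        def _save(self) -> None:
            # Stand-in for NanoVectorDB's client.save() file write.
            print("index flushed to disk")

        async def index_done_callback(self) -> None:
            # Serialize saves so concurrent callbacks cannot interleave
            # writes to the same file.
            async with self._save_lock:
                self._save()


    async def main() -> None:
        idx = FileBackedIndex()
        await asyncio.gather(idx.index_done_callback(), idx.index_done_callback())


    asyncio.run(main())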

View File

@@ -3,20 +3,11 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, Union, Tuple, List, Dict
-import pipmaster as pm
+from typing import Any, List, Dict, final
+import numpy as np
 import configparser

-if not pm.is_installed("neo4j"):
-    pm.install("neo4j")
-
-from neo4j import (
-    AsyncGraphDatabase,
-    exceptions as neo4jExceptions,
-    AsyncDriver,
-    AsyncManagedTransaction,
-    GraphDatabase,
-)
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -27,12 +18,29 @@ from tenacity import (
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+import pipmaster as pm

+if not pm.is_installed("neo4j"):
+    pm.install("neo4j")
+
+try:
+    from neo4j import (
+        AsyncGraphDatabase,
+        exceptions as neo4jExceptions,
+        AsyncDriver,
+        AsyncManagedTransaction,
+        GraphDatabase,
+    )
+except ImportError as e:
+    raise ImportError(
+        "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
+    ) from e

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

+@final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
     @staticmethod
@@ -140,8 +148,9 @@ class Neo4JStorage(BaseGraphStorage):
         if self._driver:
             await self._driver.close()

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # Neo4j handles persistence automatically
+        pass

     async def _label_exists(self, label: str) -> bool:
         """Check if a label exists in the Neo4j database."""
@@ -191,7 +200,7 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return single_result["edgeExists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier.

         Args:
@@ -252,17 +261,7 @@ class Neo4JStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """Find edge between two nodes identified by their labels.
-        Args:
-            source_node_id (str): Label of the source node
-            target_node_id (str): Label of the target node
-        Returns:
-            dict: Edge properties if found, with at least {"weight": 0.0}
-            None: If error occurs
-        """
+    ) -> dict[str, str] | None:
         try:
             entity_name_label_source = source_node_id.strip('"')
             entity_name_label_target = target_node_id.strip('"')
@@ -321,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
             # Return default edge properties on error
             return {"weight": 0.0, "source_id": None, "target_id": None}

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         node_label = source_node_id.strip('"')
         """
@@ -364,7 +363,7 @@ class Neo4JStorage(BaseGraphStorage):
             )
         ),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
         Upsert a node in the Neo4j database.
@@ -405,8 +404,8 @@ class Neo4JStorage(BaseGraphStorage):
         ),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
@@ -603,7 +602,7 @@ class Neo4JStorage(BaseGraphStorage):
         await traverse(label, 0)
         return result

-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         """
         Get all existing node labels in the database
         Returns:
@@ -627,3 +626,11 @@ class Neo4JStorage(BaseGraphStorage):
             async for record in result:
                 labels.append(record["label"])
             return labels
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError

View File

@@ -1,61 +1,12 @@
"""
NetworkX Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import html import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast from typing import Any, cast, final
import networkx as nx
import numpy as np import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.utils import ( from lightrag.utils import (
logger, logger,
) )
@@ -64,7 +15,15 @@ from lightrag.base import (
BaseGraphStorage, BaseGraphStorage,
) )
try:
import networkx as nx
except ImportError as e:
raise ImportError(
"`networkx` library is not installed. Please install it via pip: `pip install networkx`."
) from e
@final
@dataclass @dataclass
class NetworkXStorage(BaseGraphStorage): class NetworkXStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -142,7 +101,7 @@ class NetworkXStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
@@ -151,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id) return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id) return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
@@ -162,35 +121,32 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id)) return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id): if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id)) return list(self._graph.edges(source_node_id))
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data) self._graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data) self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str) -> None:
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
if self._graph.has_node(node_id): if self._graph.has_node(node_id):
self._graph.remove_node(node_id) self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.") logger.info(f"Node {node_id} deleted from the graph.")
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +182,11 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges: for source, target in edges:
if self._graph.has_edge(source, target): if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target) self._graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
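The guarded import at the top of this file chains with `from e`, so the traceback shows the underlying import failure alongside the install hint instead of replacing it. The pattern in isolation, with a hypothetical dependency name:

try:
    import some_optional_dep  # hypothetical optional dependency
except ImportError as e:
    # "raise ... from e" marks the new error as directly caused by the
    # original one, so both tracebacks are printed.
    raise ImportError(
        "`some_optional_dep` is not installed. Please install it via pip: "
        "`pip install some-optional-dep`."
    ) from e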

View File

@@ -4,16 +4,11 @@ import asyncio
# import html # import html
# import os # import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any, Union, final
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("oracledb"): from lightrag.types import KnowledgeGraph
pm.install("oracledb")
import oracledb
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@@ -23,6 +18,19 @@ from ..base import (
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
try:
import oracledb
except ImportError as e:
raise ImportError(
"`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
) from e
class OracleDB: class OracleDB:
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
@@ -169,6 +177,7 @@ class OracleDB:
raise raise
@final
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -181,7 +190,7 @@ class OracleKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data based on id.""" """Get doc_full data based on id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +241,7 @@ class OracleKVStorage(BaseKVStorage):
res = [{k: v} for k, v in dict_res.items()] res = [{k: v} for k, v in dict_res.items()]
return res return res
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
return set(keys) return set(keys)
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [ list_data = [
{ {
@@ -307,20 +316,17 @@ class OracleKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( # Oracle handles persistence automatically
self.namespace, pass
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
): async def drop(self) -> None:
logger.info("full doc and chunk data had been saved into oracle db!") raise NotImplementedError
@final
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
@@ -330,16 +336,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据"""
pass
async def index_done_callback(self):
pass
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""从向量数据库中查询数据"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
# Convert precision # Convert precision
@@ -359,21 +357,29 @@ class OracleVectorDBStorage(BaseVectorStorage):
# print("vector search result:",results) # print("vector search result:",results)
return results return results
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
raise NotImplementedError
async def index_done_callback(self) -> None:
# Oracle handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
@final
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: OracleDB
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config.get("embedding_batch_num", 10) self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################ #################### insert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""插入或更新节点"""
# print("go into upsert node method")
entity_name = node_id entity_name = node_id
entity_type = node_data["entity_type"] entity_type = node_data["entity_type"]
description = node_data["description"] description = node_data["description"]
@@ -406,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
"""插入或更新边""" """插入或更新边"""
# print("go into upsert edge method") # print("go into upsert edge method")
source_name = source_node_id source_name = source_node_id
@@ -446,8 +452,9 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(
"""为节点生成向量""" self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -464,11 +471,9 @@ class OracleGraphStorage(BaseGraphStorage):
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids return embeddings, nodes_ids
async def index_done_callback(self): async def index_done_callback(self) -> None:
"""写入graphhml图文件""" # Oracles handles persistence automatically
logger.info( pass
"Node and edge data had been saved into oracle db already, so nothing to do here!"
)
#################### query method ################# #################### query method #################
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
@@ -486,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False return False
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"] SQL = SQL_TEMPLATES["has_edge"]
params = { params = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
@@ -503,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False return False
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"] SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL) # print(SQL)
@@ -521,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge degree",degree) # print("Edge degree",degree)
return degree return degree
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""根据节点id获取节点数据""" """根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"] SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
@@ -537,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""根据源和目标节点id获取边"""
SQL = SQL_TEMPLATES["get_edge"] SQL = SQL_TEMPLATES["get_edge"]
params = { params = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
@@ -553,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
return None return None
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id): if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"] SQL = SQL_TEMPLATES["get_node_edges"]
params = {"workspace": self.db.workspace, "source_node_id": source_node_id} params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
@@ -590,6 +591,17 @@ class OracleGraphStorage(BaseGraphStorage):
if res: if res:
return res return res
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
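The recurring `# db instance must be injected before use` comment means these classes never open their own connections: the application builds one OracleDB pool and assigns it to each storage object before the first query runs. A hedged wiring sketch; the config keys, namespace value, and embedding function are assumptions, not part of this commit:

oracle_db = OracleDB(config={"user": "...", "password": "...", "dsn": "..."})  # assumed config keys

kv = OracleKVStorage(
    namespace="full_docs",          # assumed namespace value
    global_config=global_config,    # the shared LightRAG config dict
    embedding_func=embedding_func,  # an EmbeddingFunc instance
)
kv.db = oracle_db  # inject the shared pool before any get/upsert call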
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -4,24 +4,19 @@ import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Dict, List, Union, final
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("asyncpg"): from lightrag.types import KnowledgeGraph
pm.install("asyncpg")
import sys import sys
import asyncpg
from tenacity import ( from tenacity import (
retry, retry,
retry_if_exception_type, retry_if_exception_type,
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from tqdm.asyncio import tqdm as tqdm_async
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@@ -39,6 +34,20 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
try:
import asyncpg
from tqdm.asyncio import tqdm as tqdm_async
except ImportError as e:
raise ImportError(
"`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
) from e
class PostgreSQLDB: class PostgreSQLDB:
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
@@ -175,6 +184,7 @@ class PostgreSQLDB:
pass pass
@final
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -185,7 +195,7 @@ class PGKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id.""" """Get doc_full data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +250,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def filter_keys(self, keys: List[str]) -> Set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format( sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +271,7 @@ class PGKVStorage(BaseKVStorage):
print(params) print(params)
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass pass
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -287,20 +297,17 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( # PG handles persistence automatically
self.namespace, pass
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
): async def drop(self) -> None:
logger.info("full doc and chunk data had been saved into postgresql db!") raise NotImplementedError
@final
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -352,7 +359,7 @@ class PGVectorStorage(BaseVectorStorage):
} }
return upsert_sql, data return upsert_sql, data
async def upsert(self, data: Dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -398,12 +405,8 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
async def index_done_callback(self):
logger.info("vector data had been saved into postgresql db!")
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""从向量数据库中查询数据"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
@@ -417,23 +420,31 @@ class PGVectorStorage(BaseVectorStorage):
results = await self.db.query(sql, params=params, multirows=True) results = await self.db.query(sql, params=params, multirows=True)
return results return results
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
@final
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use async def filter_keys(self, keys: set[str]) -> set[str]:
# db: PostgreSQLDB
async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data]) keys = ",".join([f"'{_id}'" for _id in keys])
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
result = await self.db.query(sql, multirows=True) result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None: if result is None:
return set(data) return set(keys)
else: else:
existed = set([element["id"] for element in result]) existed = set([element["id"] for element in result])
return set(data) - existed return set(keys) - existed
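Building the IN (...) list by f-string interpolation works but skips asyncpg's parameter binding. A hedged alternative sketch, assuming direct access to an asyncpg connection (the db.query wrapper here uses its own parameter convention):

import asyncpg

async def filter_missing(conn: asyncpg.Connection, workspace: str, keys: set[str]) -> set[str]:
    # Bind the whole key set as an array parameter instead of interpolating it.
    rows = await conn.fetch(
        "SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace = $1 AND id = ANY($2::varchar[])",
        workspace,
        list(keys),
    )
    return keys - {r["id"] for r in rows}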
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -452,6 +463,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=result[0]["updated_at"], updated_at=result[0]["updated_at"],
) )
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError
async def get_status_counts(self) -> Dict[str, int]: async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count" sql = """SELECT status as "status", COUNT(1) as "count"
@@ -470,7 +484,7 @@ class PGDocStatusStorage(DocStatusStorage):
) -> Dict[str, DocProcessingStatus]: ) -> Dict[str, DocProcessingStatus]:
"""all documents with a specific status""" """all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status.value}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, params, True)
return { return {
element["id"]: DocProcessingStatus( element["id"]: DocProcessingStatus(
@@ -485,11 +499,11 @@ class PGDocStatusStorage(DocStatusStorage):
for element in result for element in result
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" # PG handles persistence automatically
logger.info("Doc status had been saved into postgresql db!") pass
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Update or insert document status """Update or insert document status
Args: Args:
@@ -520,31 +534,8 @@ class PGDocStatusStorage(DocStatusStorage):
) )
return data return data
async def update_doc_status(self, data: dict[str, dict]) -> None: async def drop(self) -> None:
""" raise NotImplementedError
Updates only the document status, chunk count, and updated timestamp.
This method ensures that only relevant fields are updated instead of overwriting
the entire document record. If `updated_at` is not provided, the database will
automatically use the current timestamp.
"""
sql = """
UPDATE LIGHTRAG_DOC_STATUS
SET status = $3,
chunks_count = $4,
updated_at = CURRENT_TIMESTAMP
WHERE workspace = $1 AND id = $2
"""
for k, v in data.items():
_data = {
"workspace": self.db.workspace,
"id": k,
"status": v["status"].value, # Convert Enum to string
"chunks_count": v.get(
"chunks_count", -1
), # Default to -1 if not provided
}
await self.db.execute(sql, _data)
class PGGraphQueryException(Exception): class PGGraphQueryException(Exception):
@@ -565,11 +556,9 @@ class PGGraphQueryException(Exception):
return self.details return self.details
@final
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with AGE in production") print("no preloading of graph with AGE in production")
@@ -580,8 +569,9 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
print("KG successfully indexed.") # PG handles persistence automatically
pass
@staticmethod @staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
@@ -811,7 +801,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"')) label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:Entity {node_id: "%s"})
@@ -866,17 +856,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""
Find all edges between nodes of two given labels
Args:
source_node_id (str): Label of the source nodes
target_node_id (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
@@ -900,7 +880,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
return result return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
@@ -948,14 +928,7 @@ class PGGraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((PGGraphQueryException,)), retry=retry_if_exception_type((PGGraphQueryException,)),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the AGE database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = PGGraphStorage._encode_graph_label(node_id.strip('"')) label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data properties = node_data
@@ -986,8 +959,8 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)), retry=retry_if_exception_type((PGGraphQueryException,)),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -1029,6 +1002,22 @@ class PGGraphStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
print("Implemented but never called.") print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
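upsert_node and upsert_edge keep their tenacity decorators, so transient PGGraphQueryException failures are retried with exponential backoff before surfacing. The same pattern in isolation (the attempt bound below is illustrative, not taken from the commit):

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class TransientGraphError(Exception):
    """Stand-in for PGGraphQueryException in this sketch."""

@retry(
    stop=stop_after_attempt(3),                          # illustrative bound
    wait=wait_exponential(multiplier=1, min=4, max=10),  # backoff args as in the diff
    retry=retry_if_exception_type((TransientGraphError,)),
)
async def flaky_upsert() -> None:
    ...  # a body that may raise TransientGraphError is re-invoked by tenacity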
NAMESPACE_TABLE_MAP = { NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -7,16 +8,24 @@ import hashlib
import uuid import uuid
from ..utils import logger from ..utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
import pipmaster as pm
import configparser import configparser
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
import pipmaster as pm
if not pm.is_installed("qdrant_client"): if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client") pm.install("qdrant_client")
from qdrant_client import QdrantClient, models try:
from qdrant_client import QdrantClient, models
config = configparser.ConfigParser() except ImportError as e:
config.read("config.ini", "utf-8") raise ImportError(
"`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
) from e
def compute_mdhash_id_for_qdrant( def compute_mdhash_id_for_qdrant(
@@ -47,10 +56,9 @@ def compute_mdhash_id_for_qdrant(
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.") raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
@final
@dataclass @dataclass
class QdrantVectorDBStorage(BaseVectorStorage): class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs client: QdrantClient, collection_name: str, **kwargs
@@ -85,7 +93,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
), ),
) )
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
return [] return
@@ -130,7 +138,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query(self, query, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,
@@ -143,3 +151,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"query result: {results}") logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
async def index_done_callback(self) -> None:
# Qdrant handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
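The per-backend `cosine_better_than_threshold: float = None` class attributes are gone; the base class now carries a default, and the SQL-backed variants validate the value in __post_init__ and fail fast when the config omits it. The validation pattern as it appears in those hunks (the error message is assumed, since the diff truncates the raise):

def __post_init__(self):
    config = self.global_config.get("vector_db_storage_cls_kwargs", {})
    cosine_threshold = config.get("cosine_better_than_threshold")
    if cosine_threshold is None:
        raise ValueError(
            "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
        )
    self.cosine_better_than_threshold = cosine_threshold

Callers would then pass the threshold through the shared config, e.g. vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2} (constructor plumbing assumed).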

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Union from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import pipmaster as pm import pipmaster as pm
@@ -19,6 +19,7 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@final
@dataclass @dataclass
class RedisKVStorage(BaseKVStorage): class RedisKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
@@ -28,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
self._redis = Redis.from_url(redis_url, decode_responses=True) self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}") logger.info(f"Use Redis as KV {self.namespace}")
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
data = await self._redis.get(f"{self.namespace}:{id}") data = await self._redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None return json.loads(data) if data else None
@@ -39,16 +40,16 @@ class RedisKVStorage(BaseKVStorage):
results = await pipe.execute() results = await pipe.execute()
return [json.loads(result) if result else None for result in results] return [json.loads(result) if result else None for result in results]
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
for key in data: for key in keys:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute() results = await pipe.execute()
existing_ids = {data[i] for i, exists in enumerate(results) if exists} existing_ids = {key for key, exists in zip(keys, results) if exists}
return set(data) - existing_ids return set(keys) - existing_ids
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
for k, v in tqdm_async(data.items(), desc="Upserting"): for k, v in tqdm_async(data.items(), desc="Upserting"):
pipe.set(f"{self.namespace}:{k}", json.dumps(v)) pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +62,7 @@ class RedisKVStorage(BaseKVStorage):
keys = await self._redis.keys(f"{self.namespace}:*") keys = await self._redis.keys(f"{self.namespace}:*")
if keys: if keys:
await self._redis.delete(*keys) await self._redis.delete(*keys)
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
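Together these methods give the Redis backend the KV interface's dedupe-then-write flow: filter_keys reports which ids are absent, upsert pipelines only those. A usage sketch, given an already-constructed RedisKVStorage instance kv (the document payloads are illustrative):

docs = {"doc-1": {"content": "hello"}, "doc-2": {"content": "world"}}

new_ids = await kv.filter_keys(set(docs))       # ids not yet stored in Redis
await kv.upsert({k: docs[k] for k in new_ids})  # write only the missing ones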

View File

@@ -1,9 +1,18 @@
import asyncio import asyncio
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any, Union, final
import numpy as np import numpy as np
from lightrag.types import KnowledgeGraph
from tqdm import tqdm
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("pymysql"): if not pm.is_installed("pymysql"):
@@ -11,12 +20,13 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"): if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy") pm.install("sqlalchemy")
from sqlalchemy import create_engine, text try:
from tqdm import tqdm from sqlalchemy import create_engine, text
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage except ImportError as e:
from ..namespace import NameSpace, is_namespace raise ImportError(
from ..utils import logger "`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
) from e
class TiDB: class TiDB:
@@ -99,6 +109,7 @@ class TiDB:
raise raise
@final
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -110,7 +121,7 @@ class TiDBKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Fetch doc_full data by id.""" """Fetch doc_full data by id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id} params = {"id": id}
@@ -125,8 +136,7 @@ class TiDBKVStorage(BaseKVStorage):
) )
return await self.db.query(SQL, multirows=True) return await self.db.query(SQL, multirows=True)
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace), id_field=namespace_to_id(self.namespace),
@@ -147,7 +157,7 @@ class TiDBKVStorage(BaseKVStorage):
return data return data
################ INSERT full_doc AND chunks ################ ################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data) self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -200,20 +210,17 @@ class TiDBKVStorage(BaseKVStorage):
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
return left_data return left_data
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( # TiDB handles persistence automatically
self.namespace, pass
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
): async def drop(self) -> None:
logger.info("full doc and chunk data had been saved into TiDB db!") raise NotImplementedError
@final
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -227,7 +234,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector""" """Search from tidb vector"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
@@ -249,7 +256,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return results return results
###### INSERT entities And relationships ###### ###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already # ignore, upsert in TiDBKVStorage already
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -332,7 +339,18 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
async def index_done_callback(self) -> None:
# TiDB handles persistence automatically
pass
@final
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -342,7 +360,7 @@ class TiDBGraphStorage(BaseGraphStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
#################### upsert method ################ #################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id entity_name = node_id
entity_type = node_data["entity_type"] entity_type = node_data["entity_type"]
description = node_data["description"] description = node_data["description"]
@@ -373,7 +391,7 @@ class TiDBGraphStorage(BaseGraphStorage):
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
source_name = source_node_id source_name = source_node_id
target_name = target_node_id target_name = target_node_id
weight = edge_data["weight"] weight = edge_data["weight"]
@@ -409,7 +427,9 @@ class TiDBGraphStorage(BaseGraphStorage):
} }
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
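embed_nodes stays a thin dispatcher over self._node_embed_algorithms, the name-to-coroutine registry each graph backend fills in __post_init__ (only "node2vec" is registered here). The registry pattern in isolation:

class Embedder:
    def __init__(self) -> None:
        # Map algorithm names to their async implementations.
        self._node_embed_algorithms = {"node2vec": self._node2vec_embed}

    async def _node2vec_embed(self):
        ...  # would return (embeddings, node_ids)

    async def embed_nodes(self, algorithm: str):
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()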
@@ -442,14 +462,14 @@ class TiDBGraphStorage(BaseGraphStorage):
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree return degree
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"] sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace} param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param) return await self.db.query(sql, param)
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"] sql = SQL_TEMPLATES["get_edge"]
param = { param = {
"source_name": source_node_id, "source_name": source_node_id,
@@ -458,9 +478,7 @@ class TiDBGraphStorage(BaseGraphStorage):
} }
return await self.db.query(sql, param) return await self.db.query(sql, param)
async def get_node_edges( async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
sql = SQL_TEMPLATES["get_node_edges"] sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace} param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True) res = await self.db.query(sql, param, multirows=True)
@@ -470,6 +488,21 @@ class TiDBGraphStorage(BaseGraphStorage):
else: else:
return [] return []
async def index_done_callback(self) -> None:
# TiDB handles persistence automatically
pass
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
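Every concrete backend in this commit is now decorated with @final. The decorator is a runtime no-op; its force comes from static checkers, which flag any attempt to subclass a final class, keeping the hierarchy at exactly two levels: abstract base, final implementation. A minimal sketch (hypothetical names):

from typing import final

@final
class ConcreteStorage:
    pass

class Derived(ConcreteStorage):  # mypy: Cannot inherit from final class "ConcreteStorage"
    pass

# Runtime is unaffected; the constraint is purely a static-analysis contract.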
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -674,7 +674,7 @@ class LightRAG:
"content": content, "content": content,
"content_summary": self._get_content_summary(content), "content_summary": self._get_content_summary(content),
"content_length": len(content), "content_length": len(content),
"status": DocStatus.PENDING, "status": DocStatus.PENDING.value,
"created_at": datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(),
} }
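Statuses are now stored as DocStatus.PENDING.value rather than the enum member, so the dict that reaches doc_status.upsert holds a plain string and serializes without enum-aware encoders. A minimal sketch (the enum values are assumed; lightrag's DocStatus may be a StrEnum, where .value still guarantees a plain str for drivers that reject enum instances):

import json
from enum import Enum

class DocStatus(Enum):  # sketch; values assumed
    PENDING = "pending"
    PROCESSED = "processed"

# json.dumps({"status": DocStatus.PENDING}) -> TypeError for a plain Enum
print(json.dumps({"status": DocStatus.PENDING.value}))  # {"status": "pending"}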
@@ -745,7 +745,7 @@ class LightRAG:
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
doc_status_id: { doc_status_id: {
"status": DocStatus.PROCESSING, "status": DocStatus.PROCESSING.value,
"updated_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(),
"content": status_doc.content, "content": status_doc.content,
"content_summary": status_doc.content_summary, "content_summary": status_doc.content_summary,
@@ -779,10 +779,10 @@ class LightRAG:
] ]
try: try:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
await self.doc_status.update_doc_status( await self.doc_status.upsert(
{ {
doc_status_id: { doc_status_id: {
"status": DocStatus.PROCESSED, "status": DocStatus.PROCESSED.value,
"chunks_count": len(chunks), "chunks_count": len(chunks),
"content": status_doc.content, "content": status_doc.content,
"content_summary": status_doc.content_summary, "content_summary": status_doc.content_summary,
@@ -796,10 +796,10 @@ class LightRAG:
except Exception as e: except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}") logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.update_doc_status( await self.doc_status.upsert(
{ {
doc_status_id: { doc_status_id: {
"status": DocStatus.FAILED, "status": DocStatus.FAILED.value,
"error": str(e), "error": str(e),
"content": status_doc.content, "content": status_doc.content,
"content_summary": status_doc.content_summary, "content_summary": status_doc.content_summary,