Added `@final` decorators and the required abstract-method implementations to the storage classes, and cleaned up imports
This commit is contained in:
@@ -92,6 +92,7 @@ class StorageNameSpace(ABC):
|
||||
@dataclass
|
||||
class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
embedding_func: EmbeddingFunc
|
||||
cosine_better_than_threshold: float
|
||||
meta_fields: set[str] = field(default_factory=set)
|
||||
|
||||
@abstractmethod
|
||||
|
@@ -5,21 +5,11 @@ import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("psycopg-pool"):
|
||||
pm.install("psycopg-pool")
|
||||
pm.install("psycopg[binary,pool]")
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -37,6 +27,16 @@ if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
|
||||
@@ -55,6 +55,7 @@ class AGEQueryException(Exception):
|
||||
return self.details
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class AGEStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage):
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage):
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
@@ -1,19 +1,25 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, final
|
||||
import numpy as np
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
|
||||
try:
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"chromadb library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
"""ChromaDB vector storage implementation."""
|
||||
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
try:
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
|
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import faiss
|
||||
from typing import Any, final
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
@@ -16,7 +16,15 @@ from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"faiss library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class FaissVectorDBStorage(BaseVectorStorage):
|
||||
"""
|
||||
@@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
||||
"""
|
||||
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Grab config values if available
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
|
@@ -3,13 +3,11 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, final
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
from gremlin_python.driver.protocol import GremlinServerError
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -22,7 +20,17 @@ from lightrag.utils import logger
|
||||
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
try:
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
from gremlin_python.driver.protocol import GremlinServerError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"gremlin library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class GremlinStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage):
|
||||
if self._driver:
|
||||
self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _to_value_map(value: Any) -> str:
|
||||
|
@@ -1,56 +1,6 @@
|
||||
"""
|
||||
JsonDocStatus Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
|
||||
- NetworkX
|
||||
- NumPy
|
||||
- LightRAG
|
||||
- graspologic
|
||||
|
||||
Features:
|
||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
||||
- Query graph nodes and edges
|
||||
- Calculate node and edge degrees
|
||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, final
|
||||
|
||||
from lightrag.base import (
|
||||
DocProcessingStatus,
|
||||
@@ -64,6 +14,7 @@ from lightrag.utils import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class JsonDocStatusStorage(DocStatusStorage):
|
||||
"""JSON implementation of document status storage"""
|
||||
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
return set(data) - set(self._data.keys())
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
result: list[dict[str, Any]] = []
|
||||
@@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
return counts
|
||||
|
||||
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all failed documents"""
|
||||
return {
|
||||
k: DocProcessingStatus(**v)
|
||||
for k, v in self._data.items()
|
||||
@@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
}
|
||||
|
||||
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all pending documents"""
|
||||
return {
|
||||
k: DocProcessingStatus(**v)
|
||||
for k, v in self._data.items()
|
||||
@@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
}
|
||||
|
||||
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all processed documents"""
|
||||
return {
|
||||
k: DocProcessingStatus(**v)
|
||||
for k, v in self._data.items()
|
||||
@@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
}
|
||||
|
||||
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all processing documents"""
|
||||
return {
|
||||
k: DocProcessingStatus(**v)
|
||||
for k, v in self._data.items()
|
||||
if v["status"] == DocStatus.PROCESSING
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""Save data to file after indexing"""
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
"""Update or insert document status
|
||||
|
||||
Args:
|
||||
data: Dictionary of document IDs and their status data
|
||||
"""
|
||||
self._data.update(data)
|
||||
await self.index_done_callback()
|
||||
|
||||
@@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
return self._data.get(id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
"""Delete document status by IDs"""
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_doc_status(self, data: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
@@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||
self._lock = asyncio.Lock()
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
|
@@ -1,27 +1,29 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, final
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
import pipmaster as pm
|
||||
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("pymilvus"):
|
||||
pm.install("pymilvus")
|
||||
from pymilvus import MilvusClient
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pymilvus library is not installed. Please install it to proceed."
|
||||
)
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: MilvusClient, collection_name: str, **kwargs
|
||||
|
@@ -1,22 +1,11 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
import asyncio
|
||||
|
||||
if not pm.is_installed("pymongo"):
|
||||
pm.install("pymongo")
|
||||
|
||||
if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from typing import Any, List, Union
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
from typing import Any, List, Union, final
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
try:
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"motor, pymongo library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
@@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage):
|
||||
await self._data.drop()
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
|
||||
@@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||
"""Get all procesed documents"""
|
||||
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Implement the method here
|
||||
pass
|
||||
|
||||
async def update_doc_status(self, data: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
@@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
|
||||
return result
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
|
@@ -1,65 +1,10 @@
|
||||
"""
|
||||
NanoVectorDB Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
|
||||
- NetworkX
|
||||
- NumPy
|
||||
- LightRAG
|
||||
- graspologic
|
||||
|
||||
Features:
|
||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
||||
- Query graph nodes and edges
|
||||
- Calculate node and edge degrees
|
||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, final
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("nano-vectordb"):
|
||||
pm.install("nano-vectordb")
|
||||
|
||||
from nano_vectordb import NanoVectorDB
|
||||
import time
|
||||
|
||||
from lightrag.utils import (
|
||||
@@ -71,11 +16,17 @@ from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
try:
|
||||
from nano_vectordb import NanoVectorDB
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"nano-vectordb library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize lock only for file operations
|
||||
self._save_lock = asyncio.Lock()
|
||||
|
@@ -3,21 +3,11 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, List, Dict, final
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("neo4j"):
|
||||
pm.install("neo4j")
|
||||
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -29,11 +19,25 @@ from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
try:
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"neo4j library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def _label_exists(self, label: str) -> bool:
|
||||
"""Check if a label exists in the Neo4j database."""
|
||||
|
@@ -1,58 +1,8 @@
|
||||
"""
|
||||
NetworkX Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
|
||||
- NetworkX
|
||||
- NumPy
|
||||
- LightRAG
|
||||
- graspologic
|
||||
|
||||
Features:
|
||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
||||
- Query graph nodes and edges
|
||||
- Calculate node and edge degrees
|
||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
import html
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
import networkx as nx
|
||||
from typing import Any, cast, final
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -65,7 +15,15 @@ from lightrag.base import (
|
||||
BaseGraphStorage,
|
||||
)
|
||||
|
||||
try:
|
||||
import networkx as nx
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"networkx library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
|
@@ -4,17 +4,11 @@ import asyncio
|
||||
# import html
|
||||
# import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("oracledb"):
|
||||
pm.install("oracledb")
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
import oracledb
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -24,6 +18,14 @@ from ..base import (
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
try:
|
||||
import oracledb
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"oracledb library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class OracleDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -170,6 +172,7 @@ class OracleDB:
|
||||
raise
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
pass
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
#################### query method ###############
|
||||
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||
|
||||
#################### insert method ################
|
||||
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
return embeddings, nodes_ids
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
"""写入graphhml图文件"""
|
||||
logger.info(
|
||||
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
||||
)
|
||||
pass
|
||||
|
||||
#################### query method #################
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
|
@@ -4,26 +4,19 @@ import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import sys
|
||||
|
||||
import asyncpg
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
|
||||
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
try:
|
||||
import asyncpg
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class PostgreSQLDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -177,6 +179,7 @@ class PostgreSQLDB:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into postgresql db!")
|
||||
pass
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
logger.info("vector data had been saved into postgresql db!")
|
||||
pass
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
keys = ",".join([f"'{_id}'" for _id in data])
|
||||
keys = ",".join([f"'{_id}'" for _id in keys])
|
||||
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||
result = await self.db.query(sql, multirows=True)
|
||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||
if result is None:
|
||||
return set(data)
|
||||
return set(keys)
|
||||
else:
|
||||
existed = set([element["id"] for element in result])
|
||||
return set(data) - existed
|
||||
return set(keys) - existed
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
updated_at=result[0]["updated_at"],
|
||||
)
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_status_counts(self) -> Dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
sql = """SELECT status as "status", COUNT(1) as "count"
|
||||
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
"""Get all procesed documents"""
|
||||
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
|
||||
logger.info("Doc status had been saved into postgresql db!")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""Update or insert document status
|
||||
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
}
|
||||
await self.db.execute(sql, _data)
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PGGraphQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
|
||||
return self.details
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, final
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@@ -8,17 +8,20 @@ import hashlib
|
||||
import uuid
|
||||
from ..utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("qdrant_client"):
|
||||
pm.install("qdrant_client")
|
||||
|
||||
from qdrant_client import QdrantClient, models
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient, models
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"qdrant_client library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
def compute_mdhash_id_for_qdrant(
|
||||
content: str, prefix: str = "", style: str = "simple"
|
||||
@@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant(
|
||||
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: QdrantClient, collection_name: str, **kwargs
|
||||
|
@@ -1,24 +1,26 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymysql"):
|
||||
pm.install("pymysql")
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class TiDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -100,6 +102,7 @@ class TiDB:
|
||||
raise
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
await self.db.execute(merge_sql, data)
|
||||
return left_data
|
||||
|
||||
async def index_done_callback(self):
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into TiDB db!")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
return []
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
Reference in New Issue
Block a user