added final, required methods and cleaned import
This commit is contained in:
@@ -92,6 +92,7 @@ class StorageNameSpace(ABC):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BaseVectorStorage(StorageNameSpace, ABC):
|
class BaseVectorStorage(StorageNameSpace, ABC):
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
|
cosine_better_than_threshold: float
|
||||||
meta_fields: set[str] = field(default_factory=set)
|
meta_fields: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@@ -5,21 +5,11 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
|
|
||||||
if not pm.is_installed("psycopg-pool"):
|
|
||||||
pm.install("psycopg-pool")
|
|
||||||
pm.install("psycopg[binary,pool]")
|
|
||||||
if not pm.is_installed("asyncpg"):
|
|
||||||
pm.install("asyncpg")
|
|
||||||
|
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph
|
from lightrag.types import KnowledgeGraph
|
||||||
import psycopg
|
|
||||||
from psycopg.rows import namedtuple_row
|
|
||||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
@@ -37,6 +27,16 @@ if sys.platform.startswith("win"):
|
|||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import psycopg
|
||||||
|
from psycopg.rows import namedtuple_row
|
||||||
|
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class AGEQueryException(Exception):
|
class AGEQueryException(Exception):
|
||||||
"""Exception for the AGE queries."""
|
"""Exception for the AGE queries."""
|
||||||
|
|
||||||
@@ -55,6 +55,7 @@ class AGEQueryException(Exception):
|
|||||||
return self.details
|
return self.details
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class AGEStorage(BaseGraphStorage):
|
class AGEStorage(BaseGraphStorage):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
if self._driver:
|
if self._driver:
|
||||||
await self._driver.close()
|
await self._driver.close()
|
||||||
|
|
||||||
async def index_done_callback(self):
|
|
||||||
print("KG successfully indexed.")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
|
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
self, node_label: str, max_depth: int = 5
|
self, node_label: str, max_depth: int = 5
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
pass
|
||||||
|
@@ -1,19 +1,25 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from chromadb import HttpClient, PersistentClient
|
|
||||||
from chromadb.config import Settings
|
|
||||||
from lightrag.base import BaseVectorStorage
|
from lightrag.base import BaseVectorStorage
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from chromadb import HttpClient, PersistentClient
|
||||||
|
from chromadb.config import Settings
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"chromadb library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||||
"""ChromaDB vector storage implementation."""
|
"""ChromaDB vector storage implementation."""
|
||||||
|
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
try:
|
try:
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any, final
|
||||||
import faiss
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
@@ -16,7 +16,15 @@ from lightrag.base import (
|
|||||||
BaseVectorStorage,
|
BaseVectorStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import faiss
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"faiss library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class FaissVectorDBStorage(BaseVectorStorage):
|
class FaissVectorDBStorage(BaseVectorStorage):
|
||||||
"""
|
"""
|
||||||
@@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Grab config values if available
|
# Grab config values if available
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
@@ -3,13 +3,11 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, final
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gremlin_python.driver import client, serializer
|
|
||||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
|
||||||
from gremlin_python.driver.protocol import GremlinServerError
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
@@ -22,7 +20,17 @@ from lightrag.utils import logger
|
|||||||
|
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
|
|
||||||
|
try:
|
||||||
|
from gremlin_python.driver import client, serializer
|
||||||
|
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||||
|
from gremlin_python.driver.protocol import GremlinServerError
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"gremlin library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class GremlinStorage(BaseGraphStorage):
|
class GremlinStorage(BaseGraphStorage):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
if self._driver:
|
if self._driver:
|
||||||
self._driver.close()
|
self._driver.close()
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
print("KG successfully indexed.")
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_value_map(value: Any) -> str:
|
def _to_value_map(value: Any) -> str:
|
||||||
|
@@ -1,56 +1,6 @@
|
|||||||
"""
|
|
||||||
JsonDocStatus Storage Module
|
|
||||||
=======================
|
|
||||||
|
|
||||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
|
||||||
|
|
||||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
|
||||||
|
|
||||||
Author: lightrag team
|
|
||||||
Created: 2024-01-25
|
|
||||||
License: MIT
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
Version: 1.0.0
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- NetworkX
|
|
||||||
- NumPy
|
|
||||||
- LightRAG
|
|
||||||
- graspologic
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
|
||||||
- Query graph nodes and edges
|
|
||||||
- Calculate node and edge degrees
|
|
||||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
|
||||||
- Remove nodes and edges from the graph
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import os
|
import os
|
||||||
from typing import Any, Union
|
from typing import Any, Union, final
|
||||||
|
|
||||||
from lightrag.base import (
|
from lightrag.base import (
|
||||||
DocProcessingStatus,
|
DocProcessingStatus,
|
||||||
@@ -64,6 +14,7 @@ from lightrag.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class JsonDocStatusStorage(DocStatusStorage):
|
class JsonDocStatusStorage(DocStatusStorage):
|
||||||
"""JSON implementation of document status storage"""
|
"""JSON implementation of document status storage"""
|
||||||
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||||
|
|
||||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||||
return set(data) - set(self._data.keys())
|
return set(keys) - set(self._data.keys())
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
result: list[dict[str, Any]] = []
|
result: list[dict[str, Any]] = []
|
||||||
@@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
|
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
|
||||||
"""Get all failed documents"""
|
|
||||||
return {
|
return {
|
||||||
k: DocProcessingStatus(**v)
|
k: DocProcessingStatus(**v)
|
||||||
for k, v in self._data.items()
|
for k, v in self._data.items()
|
||||||
@@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
|
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
|
||||||
"""Get all pending documents"""
|
|
||||||
return {
|
return {
|
||||||
k: DocProcessingStatus(**v)
|
k: DocProcessingStatus(**v)
|
||||||
for k, v in self._data.items()
|
for k, v in self._data.items()
|
||||||
@@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
|
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
|
||||||
"""Get all processed documents"""
|
|
||||||
return {
|
return {
|
||||||
k: DocProcessingStatus(**v)
|
k: DocProcessingStatus(**v)
|
||||||
for k, v in self._data.items()
|
for k, v in self._data.items()
|
||||||
@@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
|
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
|
||||||
"""Get all processing documents"""
|
|
||||||
return {
|
return {
|
||||||
k: DocProcessingStatus(**v)
|
k: DocProcessingStatus(**v)
|
||||||
for k, v in self._data.items()
|
for k, v in self._data.items()
|
||||||
if v["status"] == DocStatus.PROCESSING
|
if v["status"] == DocStatus.PROCESSING
|
||||||
}
|
}
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
"""Save data to file after indexing"""
|
|
||||||
write_json(self._data, self._file_name)
|
write_json(self._data, self._file_name)
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, Any]) -> None:
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
"""Update or insert document status
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Dictionary of document IDs and their status data
|
|
||||||
"""
|
|
||||||
self._data.update(data)
|
self._data.update(data)
|
||||||
await self.index_done_callback()
|
await self.index_done_callback()
|
||||||
|
|
||||||
@@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
return self._data.get(id)
|
return self._data.get(id)
|
||||||
|
|
||||||
async def delete(self, doc_ids: list[str]):
|
async def delete(self, doc_ids: list[str]):
|
||||||
"""Delete document status by IDs"""
|
|
||||||
for doc_id in doc_ids:
|
for doc_id in doc_ids:
|
||||||
self._data.pop(doc_id, None)
|
self._data.pop(doc_id, None)
|
||||||
await self.index_done_callback()
|
await self.index_done_callback()
|
||||||
|
|
||||||
|
async def drop(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def update_doc_status(self, data: dict[str, Any]) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
@@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
write_json(self._data, self._file_name)
|
write_json(self._data, self._file_name)
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
|
@@ -1,27 +1,29 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any, final
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
from ..base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
if not pm.is_installed("pymilvus"):
|
try:
|
||||||
pm.install("pymilvus")
|
from pymilvus import MilvusClient
|
||||||
from pymilvus import MilvusClient
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pymilvus library is not installed. Please install it to proceed."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_collection_if_not_exist(
|
def create_collection_if_not_exist(
|
||||||
client: MilvusClient, collection_name: str, **kwargs
|
client: MilvusClient, collection_name: str, **kwargs
|
||||||
|
@@ -1,22 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
if not pm.is_installed("pymongo"):
|
from typing import Any, List, Union, final
|
||||||
pm.install("pymongo")
|
|
||||||
|
|
||||||
if not pm.is_installed("motor"):
|
|
||||||
pm.install("motor")
|
|
||||||
|
|
||||||
from typing import Any, List, Union
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.operations import SearchIndexModel
|
|
||||||
from pymongo.errors import PyMongoError
|
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
@@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace
|
|||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
|
try:
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.operations import SearchIndexModel
|
||||||
|
from pymongo.errors import PyMongoError
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"motor, pymongo library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoKVStorage(BaseKVStorage):
|
class MongoKVStorage(BaseKVStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
await self._data.drop()
|
await self._data.drop()
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoDocStatusStorage(DocStatusStorage):
|
class MongoDocStatusStorage(DocStatusStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|||||||
"""Get all procesed documents"""
|
"""Get all procesed documents"""
|
||||||
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
# Implement the method here
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def update_doc_status(self, data: dict[str, Any]) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoGraphStorage(BaseGraphStorage):
|
class MongoGraphStorage(BaseGraphStorage):
|
||||||
"""
|
"""
|
||||||
@@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoVectorDBStorage(BaseVectorStorage):
|
class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||||
|
@@ -1,65 +1,10 @@
|
|||||||
"""
|
|
||||||
NanoVectorDB Storage Module
|
|
||||||
=======================
|
|
||||||
|
|
||||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
|
||||||
|
|
||||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
|
||||||
|
|
||||||
Author: lightrag team
|
|
||||||
Created: 2024-01-25
|
|
||||||
License: MIT
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
Version: 1.0.0
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- NetworkX
|
|
||||||
- NumPy
|
|
||||||
- LightRAG
|
|
||||||
- graspologic
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
|
||||||
- Query graph nodes and edges
|
|
||||||
- Calculate node and edge degrees
|
|
||||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
|
||||||
- Remove nodes and edges from the graph
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any, final
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
|
|
||||||
if not pm.is_installed("nano-vectordb"):
|
|
||||||
pm.install("nano-vectordb")
|
|
||||||
|
|
||||||
from nano_vectordb import NanoVectorDB
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
@@ -71,11 +16,17 @@ from lightrag.base import (
|
|||||||
BaseVectorStorage,
|
BaseVectorStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nano_vectordb import NanoVectorDB
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"nano-vectordb library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class NanoVectorDBStorage(BaseVectorStorage):
|
class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize lock only for file operations
|
# Initialize lock only for file operations
|
||||||
self._save_lock = asyncio.Lock()
|
self._save_lock = asyncio.Lock()
|
||||||
|
@@ -3,21 +3,11 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Dict
|
from typing import Any, List, Dict, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
|
||||||
pm.install("neo4j")
|
|
||||||
|
|
||||||
from neo4j import (
|
|
||||||
AsyncGraphDatabase,
|
|
||||||
exceptions as neo4jExceptions,
|
|
||||||
AsyncDriver,
|
|
||||||
AsyncManagedTransaction,
|
|
||||||
GraphDatabase,
|
|
||||||
)
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
@@ -29,11 +19,25 @@ from ..utils import logger
|
|||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
|
try:
|
||||||
|
from neo4j import (
|
||||||
|
AsyncGraphDatabase,
|
||||||
|
exceptions as neo4jExceptions,
|
||||||
|
AsyncDriver,
|
||||||
|
AsyncManagedTransaction,
|
||||||
|
GraphDatabase,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"neo4j library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
if self._driver:
|
if self._driver:
|
||||||
await self._driver.close()
|
await self._driver.close()
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
print("KG successfully indexed.")
|
pass
|
||||||
|
|
||||||
async def _label_exists(self, label: str) -> bool:
|
async def _label_exists(self, label: str) -> bool:
|
||||||
"""Check if a label exists in the Neo4j database."""
|
"""Check if a label exists in the Neo4j database."""
|
||||||
|
@@ -1,58 +1,8 @@
|
|||||||
"""
|
|
||||||
NetworkX Storage Module
|
|
||||||
=======================
|
|
||||||
|
|
||||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
|
||||||
|
|
||||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
|
||||||
|
|
||||||
Author: lightrag team
|
|
||||||
Created: 2024-01-25
|
|
||||||
License: MIT
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
Version: 1.0.0
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- NetworkX
|
|
||||||
- NumPy
|
|
||||||
- LightRAG
|
|
||||||
- graspologic
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
|
||||||
- Query graph nodes and edges
|
|
||||||
- Calculate node and edge degrees
|
|
||||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
|
||||||
- Remove nodes and edges from the graph
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast
|
from typing import Any, cast, final
|
||||||
import networkx as nx
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +15,15 @@ from lightrag.base import (
|
|||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import networkx as nx
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"networkx library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class NetworkXStorage(BaseGraphStorage):
|
class NetworkXStorage(BaseGraphStorage):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@@ -4,17 +4,11 @@ import asyncio
|
|||||||
# import html
|
# import html
|
||||||
# import os
|
# import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union
|
from typing import Any, Union, final
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
|
|
||||||
if not pm.is_installed("oracledb"):
|
|
||||||
pm.install("oracledb")
|
|
||||||
|
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph
|
from lightrag.types import KnowledgeGraph
|
||||||
import oracledb
|
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
@@ -24,6 +18,14 @@ from ..base import (
|
|||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import oracledb
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"oracledb library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class OracleDB:
|
class OracleDB:
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
@@ -170,6 +172,7 @@ class OracleDB:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleKVStorage(BaseKVStorage):
|
class OracleKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
|
||||||
# db: OracleDB
|
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleGraphStorage(BaseGraphStorage):
|
class OracleGraphStorage(BaseGraphStorage):
|
||||||
# db instance must be injected before use
|
|
||||||
# db: OracleDB
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""从graphml文件加载图"""
|
|
||||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||||
|
|
||||||
#################### insert method ################
|
#################### insert method ################
|
||||||
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
return embeddings, nodes_ids
|
return embeddings, nodes_ids
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
"""写入graphhml图文件"""
|
pass
|
||||||
logger.info(
|
|
||||||
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
|
||||||
)
|
|
||||||
|
|
||||||
#################### query method #################
|
#################### query method #################
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
@@ -4,26 +4,19 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union, final
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph
|
from lightrag.types import KnowledgeGraph
|
||||||
|
|
||||||
if not pm.is_installed("asyncpg"):
|
|
||||||
pm.install("asyncpg")
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import asyncpg
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential,
|
wait_exponential,
|
||||||
)
|
)
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
|
|||||||
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncpg
|
||||||
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLDB:
|
class PostgreSQLDB:
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
@@ -177,6 +179,7 @@ class PostgreSQLDB:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGKVStorage(BaseKVStorage):
|
class PGKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
if is_namespace(
|
pass
|
||||||
self.namespace,
|
|
||||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
|
||||||
):
|
|
||||||
logger.info("full doc and chunk data had been saved into postgresql db!")
|
|
||||||
|
|
||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
|
||||||
# db: PostgreSQLDB
|
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
await self.db.execute(upsert_sql, data)
|
await self.db.execute(upsert_sql, data)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
logger.info("vector data had been saved into postgresql db!")
|
pass
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||||
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGDocStatusStorage(DocStatusStorage):
|
class PGDocStatusStorage(DocStatusStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
# db: PostgreSQLDB
|
# db: PostgreSQLDB
|
||||||
|
|
||||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Return keys that don't exist in storage"""
|
"""Return keys that don't exist in storage"""
|
||||||
keys = ",".join([f"'{_id}'" for _id in data])
|
keys = ",".join([f"'{_id}'" for _id in keys])
|
||||||
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||||
result = await self.db.query(sql, multirows=True)
|
result = await self.db.query(sql, multirows=True)
|
||||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||||
if result is None:
|
if result is None:
|
||||||
return set(data)
|
return set(keys)
|
||||||
else:
|
else:
|
||||||
existed = set([element["id"] for element in result])
|
existed = set([element["id"] for element in result])
|
||||||
return set(data) - existed
|
return set(keys) - existed
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||||
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
updated_at=result[0]["updated_at"],
|
updated_at=result[0]["updated_at"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_status_counts(self) -> Dict[str, int]:
|
async def get_status_counts(self) -> Dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
sql = """SELECT status as "status", COUNT(1) as "count"
|
sql = """SELECT status as "status", COUNT(1) as "count"
|
||||||
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
"""Get all procesed documents"""
|
"""Get all procesed documents"""
|
||||||
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
|
pass
|
||||||
logger.info("Doc status had been saved into postgresql db!")
|
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
"""Update or insert document status
|
"""Update or insert document status
|
||||||
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
}
|
}
|
||||||
await self.db.execute(sql, _data)
|
await self.db.execute(sql, _data)
|
||||||
|
|
||||||
|
async def drop(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PGGraphQueryException(Exception):
|
class PGGraphQueryException(Exception):
|
||||||
"""Exception for the AGE queries."""
|
"""Exception for the AGE queries."""
|
||||||
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
|
|||||||
return self.details
|
return self.details
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGGraphStorage(BaseGraphStorage):
|
class PGGraphStorage(BaseGraphStorage):
|
||||||
# db instance must be injected before use
|
|
||||||
# db: PostgreSQLDB
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_nx_graph(file_name):
|
def load_nx_graph(file_name):
|
||||||
print("no preloading of graph with AGE in production")
|
print("no preloading of graph with AGE in production")
|
||||||
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
"node2vec": self._node2vec_embed,
|
"node2vec": self._node2vec_embed,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
print("KG successfully indexed.")
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any, final
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -8,17 +8,20 @@ import hashlib
|
|||||||
import uuid
|
import uuid
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
import pipmaster as pm
|
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
if not pm.is_installed("qdrant_client"):
|
|
||||||
pm.install("qdrant_client")
|
|
||||||
|
|
||||||
from qdrant_client import QdrantClient, models
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from qdrant_client import QdrantClient, models
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"qdrant_client library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def compute_mdhash_id_for_qdrant(
|
def compute_mdhash_id_for_qdrant(
|
||||||
content: str, prefix: str = "", style: str = "simple"
|
content: str, prefix: str = "", style: str = "simple"
|
||||||
@@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant(
|
|||||||
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_collection_if_not_exist(
|
def create_collection_if_not_exist(
|
||||||
client: QdrantClient, collection_name: str, **kwargs
|
client: QdrantClient, collection_name: str, **kwargs
|
||||||
|
@@ -1,24 +1,26 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union
|
from typing import Any, Union, final
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
|
||||||
|
|
||||||
if not pm.is_installed("pymysql"):
|
|
||||||
pm.install("pymysql")
|
|
||||||
if not pm.is_installed("sqlalchemy"):
|
|
||||||
pm.install("sqlalchemy")
|
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph
|
from lightrag.types import KnowledgeGraph
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class TiDB:
|
class TiDB:
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
@@ -100,6 +102,7 @@ class TiDB:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBKVStorage(BaseKVStorage):
|
class TiDBKVStorage(BaseKVStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
return left_data
|
return left_data
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
if is_namespace(
|
pass
|
||||||
self.namespace,
|
|
||||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
|
||||||
):
|
|
||||||
logger.info("full doc and chunk data had been saved into TiDB db!")
|
|
||||||
|
|
||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
|
||||||
# db: TiDB
|
|
||||||
cosine_better_than_threshold: float = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
"""Delete relations for a given entity by scanning metadata"""
|
"""Delete relations for a given entity by scanning metadata"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBGraphStorage(BaseGraphStorage):
|
class TiDBGraphStorage(BaseGraphStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user