implemented method and cleaned the mess
This commit is contained in:
@@ -1,53 +1,3 @@
|
|||||||
"""
|
|
||||||
JsonDocStatus Storage Module
|
|
||||||
=======================
|
|
||||||
|
|
||||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
|
||||||
|
|
||||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
|
||||||
|
|
||||||
Author: lightrag team
|
|
||||||
Created: 2024-01-25
|
|
||||||
License: MIT
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
Version: 1.0.0
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- NetworkX
|
|
||||||
- NumPy
|
|
||||||
- LightRAG
|
|
||||||
- graspologic
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
|
||||||
- Query graph nodes and edges
|
|
||||||
- Calculate node and edge degrees
|
|
||||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
|
||||||
- Remove nodes and edges from the graph
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -58,12 +8,10 @@ from lightrag.utils import (
|
|||||||
load_json,
|
load_json,
|
||||||
write_json,
|
write_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lightrag.base import (
|
from lightrag.base import (
|
||||||
BaseKVStorage,
|
BaseKVStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JsonKVStorage(BaseKVStorage):
|
class JsonKVStorage(BaseKVStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -79,10 +27,10 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
write_json(self._data, self._file_name)
|
write_json(self._data, self._file_name)
|
||||||
|
|
||||||
async def get_by_id(self, id: str):
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
return self._data.get(id, None)
|
return self._data.get(id, None)
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]):
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
{k: v for k, v in self._data[id].items()}
|
{k: v for k, v in self._data[id].items()}
|
||||||
@@ -95,12 +43,11 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||||
return set([s for s in data if s not in self._data])
|
return set([s for s in data if s not in self._data])
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]):
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||||
self._data.update(left_data)
|
self._data.update(left_data)
|
||||||
return left_data
|
|
||||||
|
|
||||||
async def drop(self):
|
async def drop(self) -> None:
|
||||||
self._data = {}
|
self._data = {}
|
||||||
|
|
||||||
async def get_by_status_and_ids(
|
async def get_by_status_and_ids(
|
||||||
|
@@ -50,7 +50,7 @@ Usage:
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union, Dict
|
from typing import Any, Union, Dict
|
||||||
|
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
@@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
"""Save data to file after indexing"""
|
"""Save data to file after indexing"""
|
||||||
write_json(self._data, self._file_name)
|
write_json(self._data, self._file_name)
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
"""Update or insert document status
|
"""Update or insert document status
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -114,7 +114,7 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
await self.index_done_callback()
|
await self.index_done_callback()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def get_by_id(self, id: str):
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
return self._data.get(id)
|
return self._data.get(id)
|
||||||
|
|
||||||
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
|
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
|
||||||
|
@@ -12,7 +12,7 @@ if not pm.is_installed("motor"):
|
|||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
from typing import Union, List, Tuple
|
from typing import Any, TypeVar, Union, List, Tuple
|
||||||
|
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseKVStorage, BaseGraphStorage
|
from ..base import BaseKVStorage, BaseGraphStorage
|
||||||
@@ -32,18 +32,11 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
async def all_keys(self) -> list[str]:
|
async def all_keys(self) -> list[str]:
|
||||||
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
||||||
|
|
||||||
async def get_by_id(self, id):
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
return self._data.find_one({"_id": id})
|
return self._data.find_one({"_id": id})
|
||||||
|
|
||||||
async def get_by_ids(self, ids, fields=None):
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
if fields is None:
|
|
||||||
return list(self._data.find({"_id": {"$in": ids}}))
|
return list(self._data.find({"_id": {"$in": ids}}))
|
||||||
return list(
|
|
||||||
self._data.find(
|
|
||||||
{"_id": {"$in": ids}},
|
|
||||||
{field: 1 for field in fields},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||||
existing_ids = [
|
existing_ids = [
|
||||||
@@ -51,7 +44,7 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
]
|
]
|
||||||
return set([s for s in data if s not in existing_ids])
|
return set([s for s in data if s not in existing_ids])
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
for mode, items in data.items():
|
for mode, items in data.items():
|
||||||
for k, v in tqdm_async(items.items(), desc="Upserting"):
|
for k, v in tqdm_async(items.items(), desc="Upserting"):
|
||||||
@@ -66,7 +59,6 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
||||||
data[k]["_id"] = k
|
data[k]["_id"] = k
|
||||||
return data
|
|
||||||
|
|
||||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
@@ -81,9 +73,15 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def drop(self):
|
async def drop(self) -> None:
|
||||||
""" """
|
"""Drop the collection"""
|
||||||
pass
|
await self._data.drop()
|
||||||
|
|
||||||
|
async def get_by_status_and_ids(
|
||||||
|
self, status: str
|
||||||
|
) -> Union[list[dict[str, Any]], None]:
|
||||||
|
"""Get documents by status and ids"""
|
||||||
|
return self._data.find({"status": status})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
# import html
|
# import html
|
||||||
# import os
|
# import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Any, TypeVar, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import array
|
import array
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
################ QUERY METHODS ################
|
################ QUERY METHODS ################
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
"""get doc_full data based on id."""
|
"""get doc_full data based on id."""
|
||||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
params = {"workspace": self.db.workspace, "id": id}
|
params = {"workspace": self.db.workspace, "id": id}
|
||||||
@@ -211,7 +211,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
"""get doc_chunks data based on id"""
|
"""get doc_chunks data based on id"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
@@ -238,14 +238,9 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_by_status_and_ids(
|
async def get_by_status_and_ids(
|
||||||
self, status: str, ids: list[str]
|
self, status: str
|
||||||
) -> Union[list[dict], None]:
|
) -> Union[list[dict[str, Any]], None]:
|
||||||
"""Specifically for llm_response_cache."""
|
"""Specifically for llm_response_cache."""
|
||||||
if ids is not None:
|
|
||||||
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
|
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||||
params = {"workspace": self.db.workspace, "status": status}
|
params = {"workspace": self.db.workspace, "status": status}
|
||||||
res = await self.db.query(SQL, params, multirows=True)
|
res = await self.db.query(SQL, params, multirows=True)
|
||||||
@@ -270,7 +265,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
return set(keys)
|
return set(keys)
|
||||||
|
|
||||||
################ INSERT METHODS ################
|
################ INSERT METHODS ################
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
list_data = [
|
list_data = [
|
||||||
{
|
{
|
||||||
@@ -328,14 +323,6 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
return None
|
|
||||||
|
|
||||||
async def change_status(self, id: str, status: str):
|
|
||||||
SQL = SQL_TEMPLATES["change_status"].format(
|
|
||||||
table_name=namespace_to_table_name(self.namespace)
|
|
||||||
)
|
|
||||||
params = {"workspace": self.db.workspace, "id": id, "status": status}
|
|
||||||
await self.db.execute(SQL, params)
|
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
if is_namespace(
|
if is_namespace(
|
||||||
@@ -344,7 +331,6 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
):
|
):
|
||||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
# should pass db object to self.db
|
# should pass db object to self.db
|
||||||
@@ -745,7 +731,6 @@ SQL_TEMPLATES = {
|
|||||||
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
||||||
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
||||||
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
||||||
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
|
|
||||||
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.id = :id and a.workspace = :workspace)
|
ON (a.id = :id and a.workspace = :workspace)
|
||||||
|
@@ -30,7 +30,6 @@ from ..base import (
|
|||||||
DocStatus,
|
DocStatus,
|
||||||
DocProcessingStatus,
|
DocProcessingStatus,
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
T,
|
|
||||||
)
|
)
|
||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
|
|
||||||
@@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
################ QUERY METHODS ################
|
################ QUERY METHODS ################
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
"""Get doc_full data by id."""
|
"""Get doc_full data by id."""
|
||||||
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
params = {"workspace": self.db.workspace, "id": id}
|
params = {"workspace": self.db.workspace, "id": id}
|
||||||
@@ -214,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
"""Get doc_chunks data by id"""
|
"""Get doc_chunks data by id"""
|
||||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
@@ -239,6 +238,14 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_by_status_and_ids(
|
||||||
|
self, status: str
|
||||||
|
) -> Union[list[dict[str, Any]], None]:
|
||||||
|
"""Specifically for llm_response_cache."""
|
||||||
|
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||||
|
params = {"workspace": self.db.workspace, "status": status}
|
||||||
|
return await self.db.query(SQL, params, multirows=True)
|
||||||
|
|
||||||
async def all_keys(self) -> list[dict]:
|
async def all_keys(self) -> list[dict]:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
sql = "select workspace,mode,id from lightrag_llm_cache"
|
sql = "select workspace,mode,id from lightrag_llm_cache"
|
||||||
@@ -270,7 +277,7 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
print(params)
|
print(params)
|
||||||
|
|
||||||
################ INSERT METHODS ################
|
################ INSERT METHODS ################
|
||||||
async def upsert(self, data: Dict[str, dict]):
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
pass
|
pass
|
||||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||||
@@ -447,7 +454,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||||||
existed = set([element["id"] for element in result])
|
existed = set([element["id"] for element in result])
|
||||||
return set(data) - existed
|
return set(data) - existed
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[T, None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||||
params = {"workspace": self.db.workspace, "id": id}
|
params = {"workspace": self.db.workspace, "id": id}
|
||||||
result = await self.db.query(sql, params, True)
|
result = await self.db.query(sql, params, True)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Any, TypeVar, Union
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
@@ -28,21 +29,11 @@ class RedisKVStorage(BaseKVStorage):
|
|||||||
data = await self._redis.get(f"{self.namespace}:{id}")
|
data = await self._redis.get(f"{self.namespace}:{id}")
|
||||||
return json.loads(data) if data else None
|
return json.loads(data) if data else None
|
||||||
|
|
||||||
async def get_by_ids(self, ids, fields=None):
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
pipe = self._redis.pipeline()
|
pipe = self._redis.pipeline()
|
||||||
for id in ids:
|
for id in ids:
|
||||||
pipe.get(f"{self.namespace}:{id}")
|
pipe.get(f"{self.namespace}:{id}")
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
if fields:
|
|
||||||
# Filter fields if specified
|
|
||||||
return [
|
|
||||||
{field: value.get(field) for field in fields if field in value}
|
|
||||||
if (value := json.loads(result))
|
|
||||||
else None
|
|
||||||
for result in results
|
|
||||||
]
|
|
||||||
|
|
||||||
return [json.loads(result) if result else None for result in results]
|
return [json.loads(result) if result else None for result in results]
|
||||||
|
|
||||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||||
@@ -54,7 +45,7 @@ class RedisKVStorage(BaseKVStorage):
|
|||||||
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
|
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
|
||||||
return set(data) - existing_ids
|
return set(data) - existing_ids
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
pipe = self._redis.pipeline()
|
pipe = self._redis.pipeline()
|
||||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||||
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
||||||
@@ -62,9 +53,18 @@ class RedisKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
for k in data:
|
for k in data:
|
||||||
data[k]["_id"] = k
|
data[k]["_id"] = k
|
||||||
return data
|
|
||||||
|
|
||||||
async def drop(self):
|
async def drop(self) -> None:
|
||||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||||
if keys:
|
if keys:
|
||||||
await self._redis.delete(*keys)
|
await self._redis.delete(*keys)
|
||||||
|
|
||||||
|
async def get_by_status_and_ids(
|
||||||
|
self, status: str,
|
||||||
|
) -> Union[list[dict[str, Any]], None]:
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
for key in await self._redis.keys(f"{self.namespace}:*"):
|
||||||
|
pipe.hgetall(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
return [data for data in results if data.get("status") == status] or None
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Any, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
@@ -108,7 +108,7 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
################ QUERY METHODS ################
|
################ QUERY METHODS ################
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
"""根据 id 获取 doc_full 数据."""
|
"""根据 id 获取 doc_full 数据."""
|
||||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
params = {"id": id}
|
params = {"id": id}
|
||||||
@@ -122,16 +122,14 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||||
"""根据 id 获取 doc_chunks 数据"""
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
)
|
)
|
||||||
# print("get_by_ids:"+SQL)
|
|
||||||
res = await self.db.query(SQL, multirows=True)
|
res = await self.db.query(SQL, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = res # [{"data":i} for i in res]
|
data = res # [{"data":i} for i in res]
|
||||||
# print(data)
|
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -158,7 +156,7 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
################ INSERT full_doc AND chunks ################
|
################ INSERT full_doc AND chunks ################
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, Any]) -> None:
|
||||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||||
self._data.update(left_data)
|
self._data.update(left_data)
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
@@ -335,6 +333,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
|
async def get_by_status_and_ids(
|
||||||
|
self, status: str
|
||||||
|
) -> Union[list[dict], None]:
|
||||||
|
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||||
|
params = {"workspace": self.db.workspace, "status": status}
|
||||||
|
return await self.db.query(SQL, params, multirows=True)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBGraphStorage(BaseGraphStorage):
|
class TiDBGraphStorage(BaseGraphStorage):
|
||||||
|
Reference in New Issue
Block a user