Merge branch 'main' into graph-viewer-webui
This commit is contained in:
@@ -1,63 +1,13 @@
|
||||
"""
|
||||
JsonDocStatus Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
|
||||
- NetworkX
|
||||
- NumPy
|
||||
- LightRAG
|
||||
- graspologic
|
||||
|
||||
Features:
|
||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
||||
- Query graph nodes and edges
|
||||
- Calculate node and edge degrees
|
||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
load_json,
|
||||
write_json,
|
||||
)
|
||||
|
||||
from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
@@ -68,25 +18,20 @@ class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = load_json(self._file_name) or {}
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
self._lock = asyncio.Lock()
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
return list(self._data.keys())
|
||||
|
||||
async def index_done_callback(self):
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def get_by_id(self, id):
|
||||
return self._data.get(id, None)
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
return self._data.get(id, {})
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
if fields is None:
|
||||
return [self._data.get(id, None) for id in ids]
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items() if k in fields}
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
@@ -96,39 +41,9 @@ class JsonKVStorage(BaseKVStorage):
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
return set([s for s in data if s not in self._data])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
return left_data
|
||||
|
||||
async def drop(self):
|
||||
async def drop(self) -> None:
|
||||
self._data = {}
|
||||
|
||||
async def filter(self, filter_func):
|
||||
"""Filter key-value pairs based on a filter function
|
||||
|
||||
Args:
|
||||
filter_func: The filter function, which takes a value as an argument and returns a boolean value
|
||||
|
||||
Returns:
|
||||
Dict: Key-value pairs that meet the condition
|
||||
"""
|
||||
result = {}
|
||||
async with self._lock:
|
||||
for key, value in self._data.items():
|
||||
if filter_func(value):
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete data with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
"""
|
||||
async with self._lock:
|
||||
for id in ids:
|
||||
if id in self._data:
|
||||
del self._data[id]
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
|
||||
|
@@ -50,7 +50,7 @@ Usage:
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, Dict
|
||||
from typing import Any, Union
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
@@ -72,7 +72,7 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = load_json(self._file_name) or {}
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
@@ -85,18 +85,18 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
]
|
||||
)
|
||||
|
||||
async def get_status_counts(self) -> Dict[str, int]:
|
||||
async def get_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
counts = {status: 0 for status in DocStatus}
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
return counts
|
||||
|
||||
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
|
||||
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all failed documents"""
|
||||
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
|
||||
|
||||
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
|
||||
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all pending documents"""
|
||||
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
|
||||
|
||||
@@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
"""Save data to file after indexing"""
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
"""Update or insert document status
|
||||
|
||||
Args:
|
||||
@@ -112,10 +112,9 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
"""
|
||||
self._data.update(data)
|
||||
await self.index_done_callback()
|
||||
return data
|
||||
|
||||
async def get_by_id(self, id: str):
|
||||
return self._data.get(id)
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
return self._data.get(id, {})
|
||||
|
||||
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
|
||||
"""Get document status by ID"""
|
||||
|
@@ -12,7 +12,7 @@ if not pm.is_installed("motor"):
|
||||
|
||||
from pymongo import MongoClient
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Any, Union, List, Tuple
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseKVStorage, BaseGraphStorage
|
||||
@@ -29,21 +29,11 @@ class MongoKVStorage(BaseKVStorage):
|
||||
self._data = database.get_collection(self.namespace)
|
||||
logger.info(f"Use MongoDB as KV {self.namespace}")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
||||
|
||||
async def get_by_id(self, id):
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
return self._data.find_one({"_id": id})
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
if fields is None:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
return list(
|
||||
self._data.find(
|
||||
{"_id": {"$in": ids}},
|
||||
{field: 1 for field in fields},
|
||||
)
|
||||
)
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
existing_ids = [
|
||||
@@ -51,7 +41,7 @@ class MongoKVStorage(BaseKVStorage):
|
||||
]
|
||||
return set([s for s in data if s not in existing_ids])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
for mode, items in data.items():
|
||||
for k, v in tqdm_async(items.items(), desc="Upserting"):
|
||||
@@ -66,7 +56,6 @@ class MongoKVStorage(BaseKVStorage):
|
||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
||||
data[k]["_id"] = k
|
||||
return data
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
@@ -81,9 +70,9 @@ class MongoKVStorage(BaseKVStorage):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def drop(self):
|
||||
""" """
|
||||
pass
|
||||
async def drop(self) -> None:
|
||||
"""Drop the collection"""
|
||||
await self._data.drop()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@@ -4,7 +4,7 @@ import asyncio
|
||||
# import html
|
||||
# import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
import numpy as np
|
||||
import array
|
||||
import pipmaster as pm
|
||||
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
"""get doc_full data based on id."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
@@ -191,12 +191,9 @@ class OracleKVStorage(BaseKVStorage):
|
||||
res = {}
|
||||
for row in array_res:
|
||||
res[row["id"]] = row
|
||||
else:
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
return await self.db.query(SQL, params)
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
@@ -211,7 +208,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""get doc_chunks data based on id"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
@@ -230,29 +227,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
for row in res:
|
||||
dict_res[row["mode"]][row["id"]] = row
|
||||
res = [{k: v} for k, v in dict_res.items()]
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
# print(data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str, ids: list[str]
|
||||
) -> Union[list[dict], None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
if ids is not None:
|
||||
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
else:
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
return res
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
@@ -270,7 +245,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
return set(keys)
|
||||
|
||||
################ INSERT METHODS ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
list_data = [
|
||||
{
|
||||
@@ -328,14 +303,6 @@ class OracleKVStorage(BaseKVStorage):
|
||||
}
|
||||
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
return None
|
||||
|
||||
async def change_status(self, id: str, status: str):
|
||||
SQL = SQL_TEMPLATES["change_status"].format(
|
||||
table_name=namespace_to_table_name(self.namespace)
|
||||
)
|
||||
params = {"workspace": self.db.workspace, "id": id, "status": status}
|
||||
await self.db.execute(SQL, params)
|
||||
|
||||
async def index_done_callback(self):
|
||||
if is_namespace(
|
||||
@@ -745,7 +712,6 @@ SQL_TEMPLATES = {
|
||||
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
||||
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
||||
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
||||
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
|
||||
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
||||
USING DUAL
|
||||
ON (a.id = :id and a.workspace = :workspace)
|
||||
|
@@ -30,7 +30,6 @@ from ..base import (
|
||||
DocStatus,
|
||||
DocProcessingStatus,
|
||||
BaseGraphStorage,
|
||||
T,
|
||||
)
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
|
||||
@@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage):
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
"""Get doc_full data by id."""
|
||||
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
@@ -193,12 +192,9 @@ class PGKVStorage(BaseKVStorage):
|
||||
res = {}
|
||||
for row in array_res:
|
||||
res[row["id"]] = row
|
||||
else:
|
||||
res = await self.db.query(sql, params)
|
||||
if res:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
return await self.db.query(sql, params)
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
@@ -214,7 +210,7 @@ class PGKVStorage(BaseKVStorage):
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get doc_chunks data by id"""
|
||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
@@ -231,23 +227,15 @@ class PGKVStorage(BaseKVStorage):
|
||||
dict_res[mode] = {}
|
||||
for row in array_res:
|
||||
dict_res[row["mode"]][row["id"]] = row
|
||||
res = [{k: v} for k, v in dict_res.items()]
|
||||
return [{k: v} for k, v in dict_res.items()]
|
||||
else:
|
||||
res = await self.db.query(sql, params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
return await self.db.query(sql, params, multirows=True)
|
||||
|
||||
async def all_keys(self) -> list[dict]:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
sql = "select workspace,mode,id from lightrag_llm_cache"
|
||||
res = await self.db.query(sql, multirows=True)
|
||||
return res
|
||||
else:
|
||||
logger.error(
|
||||
f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
|
||||
)
|
||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
||||
"""Filter out duplicated content"""
|
||||
@@ -270,7 +258,7 @@ class PGKVStorage(BaseKVStorage):
|
||||
print(params)
|
||||
|
||||
################ INSERT METHODS ################
|
||||
async def upsert(self, data: Dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
pass
|
||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||
@@ -447,14 +435,15 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
existed = set([element["id"] for element in result])
|
||||
return set(data) - existed
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[T, None]:
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
result = await self.db.query(sql, params, True)
|
||||
if result is None or result == []:
|
||||
return None
|
||||
return {}
|
||||
else:
|
||||
return DocProcessingStatus(
|
||||
content=result[0]["content"],
|
||||
content_length=result[0]["content_length"],
|
||||
content_summary=result[0]["content_summary"],
|
||||
status=result[0]["status"],
|
||||
@@ -483,10 +472,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
result = await self.db.query(sql, params, True)
|
||||
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
|
||||
# Converting to be a dict
|
||||
return {
|
||||
element["id"]: DocProcessingStatus(
|
||||
content=result[0]["content"],
|
||||
content_summary=element["content_summary"],
|
||||
content_length=element["content_length"],
|
||||
status=element["status"],
|
||||
@@ -518,6 +506,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status)
|
||||
values($1,$2,$3,$4,$5,$6)
|
||||
on conflict(id,workspace) do update set
|
||||
content = EXCLUDED.content,
|
||||
content_summary = EXCLUDED.content_summary,
|
||||
content_length = EXCLUDED.content_length,
|
||||
chunks_count = EXCLUDED.chunks_count,
|
||||
@@ -530,6 +519,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
{
|
||||
"workspace": self.db.workspace,
|
||||
"id": k,
|
||||
"content": v["content"],
|
||||
"content_summary": v["content_summary"],
|
||||
"content_length": v["content_length"],
|
||||
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
@@ -20,29 +21,15 @@ class RedisKVStorage(BaseKVStorage):
|
||||
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
||||
logger.info(f"Use Redis as KV {self.namespace}")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||
return [key.split(":", 1)[-1] for key in keys]
|
||||
|
||||
async def get_by_id(self, id):
|
||||
data = await self._redis.get(f"{self.namespace}:{id}")
|
||||
return json.loads(data) if data else None
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
pipe = self._redis.pipeline()
|
||||
for id in ids:
|
||||
pipe.get(f"{self.namespace}:{id}")
|
||||
results = await pipe.execute()
|
||||
|
||||
if fields:
|
||||
# Filter fields if specified
|
||||
return [
|
||||
{field: value.get(field) for field in fields if field in value}
|
||||
if (value := json.loads(result))
|
||||
else None
|
||||
for result in results
|
||||
]
|
||||
|
||||
return [json.loads(result) if result else None for result in results]
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
@@ -54,7 +41,7 @@ class RedisKVStorage(BaseKVStorage):
|
||||
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
|
||||
return set(data) - existing_ids
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
pipe = self._redis.pipeline()
|
||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
||||
@@ -62,9 +49,8 @@ class RedisKVStorage(BaseKVStorage):
|
||||
|
||||
for k in data:
|
||||
data[k]["_id"] = k
|
||||
return data
|
||||
|
||||
async def drop(self):
|
||||
async def drop(self) -> None:
|
||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
@@ -108,33 +108,20 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
"""根据 id 获取 doc_full 数据."""
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
"""Fetch doc_full data by id."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
data = res # {"data":res}
|
||||
# print (data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
return await self.db.query(SQL, params)
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Fetch doc_chunks data by id"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
# print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
# print(data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
return await self.db.query(SQL, multirows=True)
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
@@ -158,7 +145,7 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
return data
|
||||
|
||||
################ INSERT full_doc AND chunks ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, Any]) -> None:
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
@@ -335,6 +322,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
|
Reference in New Issue
Block a user