Merge branch 'main' into graph-viewer-webui
@@ -1,20 +1,18 @@
-from enum import Enum
 import os
 from dataclasses import dataclass, field
 from typing import (
+    Optional,
     TypedDict,
     Union,
     Literal,
-    Generic,
-    TypeVar,
-    Optional,
-    Dict,
     Any,
-    List,
 )
+from enum import Enum

 import numpy as np


 from .utils import EmbeddingFunc

 TextChunkSchema = TypedDict(
@@ -45,7 +43,7 @@ class QueryParam:
     hl_keywords: list[str] = field(default_factory=list)
     ll_keywords: list[str] = field(default_factory=list)
     # Conversation history support
-    conversation_history: list[dict] = field(
+    conversation_history: list[dict[str, str]] = field(
        default_factory=list
     )  # Format: [{"role": "user/assistant", "content": "message"}]
     history_turns: int = (
@@ -56,7 +54,7 @@ class QueryParam:
 @dataclass
 class StorageNameSpace:
     namespace: str
-    global_config: dict
+    global_config: dict[str, Any]

     async def index_done_callback(self):
         """commit the storage operations after indexing"""
@@ -72,10 +70,10 @@ class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
     meta_fields: set = field(default_factory=set)

-    async def query(self, query: str, top_k: int) -> list[dict]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         raise NotImplementedError

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """Use 'content' field from value for embedding, use key as id.
         If embedding_func is None, use 'embedding' field from value
         """
@@ -83,28 +81,23 @@ class BaseVectorStorage(StorageNameSpace):


 @dataclass
-class BaseKVStorage(Generic[T], StorageNameSpace):
+class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc

-    async def all_keys(self) -> list[str]:
+    async def get_by_id(self, id: str) -> dict[str, Any]:
         raise NotImplementedError

-    async def get_by_id(self, id: str) -> Union[T, None]:
-        raise NotImplementedError
-
-    async def get_by_ids(
-        self, ids: list[str], fields: Union[set[str], None] = None
-    ) -> list[Union[T, None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         raise NotImplementedError

     async def filter_keys(self, data: list[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError

-    async def upsert(self, data: dict[str, T]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         raise NotImplementedError

-    async def drop(self):
+    async def drop(self) -> None:
         raise NotImplementedError


@@ -151,12 +144,12 @@ class BaseGraphStorage(StorageNameSpace):
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")

-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         raise NotImplementedError

     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
-    ) -> Dict[str, List[Dict]]:
+    ) -> dict[str, list[dict]]:
         raise NotImplementedError


@@ -173,27 +166,37 @@ class DocStatus(str, Enum):
 class DocProcessingStatus:
     """Document processing status data structure"""

-    content_summary: str  # First 100 chars of document content
-    content_length: int  # Total length of document
-    status: DocStatus  # Current processing status
-    created_at: str  # ISO format timestamp
-    updated_at: str  # ISO format timestamp
-    chunks_count: Optional[int] = None  # Number of chunks after splitting
-    error: Optional[str] = None  # Error message if failed
-    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+    content: str
+    """Original content of the document"""
+    content_summary: str
+    """First 100 chars of document content, used for preview"""
+    content_length: int
+    """Total length of document"""
+    status: DocStatus
+    """Current processing status"""
+    created_at: str
+    """ISO format timestamp when document was created"""
+    updated_at: str
+    """ISO format timestamp when document was last updated"""
+    chunks_count: Optional[int] = None
+    """Number of chunks after splitting, used for processing"""
+    error: Optional[str] = None
+    """Error message if failed"""
+    metadata: dict[str, Any] = field(default_factory=dict)
+    """Additional metadata"""


 class DocStatusStorage(BaseKVStorage):
     """Base class for document status storage"""

-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         raise NotImplementedError

-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
         raise NotImplementedError

-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
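Note: after this refactor every KV backend implements the same plain-dict contract (no Generic[T], no all_keys, no field filtering). Below is a minimal in-memory sketch of a conforming subclass; the class name and backing dict are illustrative, not part of this commit:

    from dataclasses import dataclass
    from typing import Any

    from lightrag.base import BaseKVStorage


    @dataclass
    class InMemoryKVStorage(BaseKVStorage):
        """Illustrative only: the smallest storage satisfying the new contract."""

        def __post_init__(self):
            self._data: dict[str, Any] = {}

        async def get_by_id(self, id: str) -> dict[str, Any]:
            return self._data.get(id, {})

        async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
            return [self._data.get(id, {}) for id in ids]

        async def filter_keys(self, data: list[str]) -> set[str]:
            # "un-exist keys": those not yet stored
            return {k for k in data if k not in self._data}

        async def upsert(self, data: dict[str, Any]) -> None:
            self._data.update(data)

        async def drop(self) -> None:
            self._data = {}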
@@ -1,63 +1,13 @@
-"""
-JsonDocStatus Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
 import asyncio
 import os
 from dataclasses import dataclass
+from typing import Any

 from lightrag.utils import (
     logger,
     load_json,
     write_json,
 )

 from lightrag.base import (
     BaseKVStorage,
 )
@@ -68,25 +18,20 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._data = load_json(self._file_name) or {}
+        self._data: dict[str, Any] = load_json(self._file_name) or {}
         self._lock = asyncio.Lock()
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-    async def all_keys(self) -> list[str]:
-        return list(self._data.keys())
-
     async def index_done_callback(self):
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id):
-        return self._data.get(id, None)
+    async def get_by_id(self, id: str) -> dict[str, Any]:
+        return self._data.get(id, {})

-    async def get_by_ids(self, ids, fields=None):
-        if fields is None:
-            return [self._data.get(id, None) for id in ids]
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return [
             (
-                {k: v for k, v in self._data[id].items() if k in fields}
+                {k: v for k, v in self._data[id].items()}
                 if self._data.get(id, None)
                 else None
             )
@@ -96,39 +41,9 @@ class JsonKVStorage(BaseKVStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         return set([s for s in data if s not in self._data])

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
-        return left_data

-    async def drop(self):
+    async def drop(self) -> None:
         self._data = {}
-
-    async def filter(self, filter_func):
-        """Filter key-value pairs based on a filter function
-
-        Args:
-            filter_func: The filter function, which takes a value as an argument and returns a boolean value
-
-        Returns:
-            Dict: Key-value pairs that meet the condition
-        """
-        result = {}
-        async with self._lock:
-            for key, value in self._data.items():
-                if filter_func(value):
-                    result[key] = value
-        return result
-
-    async def delete(self, ids: list[str]):
-        """Delete data with specified IDs
-
-        Args:
-            ids: List of IDs to delete
-        """
-        async with self._lock:
-            for id in ids:
-                if id in self._data:
-                    del self._data[id]
-            await self.index_done_callback()
-            logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
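Note: with the fields filter gone, JsonKVStorage callers always receive whole records: get_by_id returns {} for a missing key, and upsert no longer returns the inserted subset. A usage sketch, where the module path and working directory are assumptions:

    import asyncio

    from lightrag.kg.json_kv_impl import JsonKVStorage  # path is illustrative


    async def main() -> None:
        kv = JsonKVStorage(
            namespace="full_docs",
            global_config={"working_dir": "./rag_storage"},
            embedding_func=None,
        )
        await kv.upsert({"doc-1": {"content": "hello world"}})  # returns None now
        print(await kv.get_by_id("doc-1"))              # {'content': 'hello world'}
        print(await kv.get_by_ids(["doc-1", "doc-2"]))  # missing ids come back as None
        await kv.index_done_callback()                  # persists the JSON file


    asyncio.run(main())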
@@ -50,7 +50,7 @@ Usage:

 import os
 from dataclasses import dataclass
-from typing import Union, Dict
+from typing import Any, Union

 from lightrag.utils import (
     logger,
@@ -72,7 +72,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._data = load_json(self._file_name) or {}
+        self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -85,18 +85,18 @@ class JsonDocStatusStorage(DocStatusStorage):
             ]
         )

-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status: 0 for status in DocStatus}
         for doc in self._data.values():
             counts[doc["status"]] += 1
         return counts

-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}

-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}

@@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         """Update or insert document status

         Args:
@@ -112,10 +112,9 @@ class JsonDocStatusStorage(DocStatusStorage):
         """
         self._data.update(data)
         await self.index_done_callback()
-        return data

-    async def get_by_id(self, id: str):
-        return self._data.get(id)
+    async def get_by_id(self, id: str) -> dict[str, Any]:
+        return self._data.get(id, {})

     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
         """Get document status by ID"""
@@ -12,7 +12,7 @@ if not pm.is_installed("motor"):
|
||||
|
||||
from pymongo import MongoClient
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Any, Union, List, Tuple
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseKVStorage, BaseGraphStorage
|
||||
@@ -29,21 +29,11 @@ class MongoKVStorage(BaseKVStorage):
|
||||
self._data = database.get_collection(self.namespace)
|
||||
logger.info(f"Use MongoDB as KV {self.namespace}")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
||||
|
||||
async def get_by_id(self, id):
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
return self._data.find_one({"_id": id})
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
if fields is None:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
return list(
|
||||
self._data.find(
|
||||
{"_id": {"$in": ids}},
|
||||
{field: 1 for field in fields},
|
||||
)
|
||||
)
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
existing_ids = [
|
||||
@@ -51,7 +41,7 @@ class MongoKVStorage(BaseKVStorage):
|
||||
]
|
||||
return set([s for s in data if s not in existing_ids])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
for mode, items in data.items():
|
||||
for k, v in tqdm_async(items.items(), desc="Upserting"):
|
||||
@@ -66,7 +56,6 @@ class MongoKVStorage(BaseKVStorage):
|
||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
||||
data[k]["_id"] = k
|
||||
return data
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
@@ -81,9 +70,9 @@ class MongoKVStorage(BaseKVStorage):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def drop(self):
|
||||
""" """
|
||||
pass
|
||||
async def drop(self) -> None:
|
||||
"""Drop the collection"""
|
||||
await self._data.drop()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
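Note: MongoKVStorage.drop() changes from an empty stub (pass) to actually dropping the collection, so callers that invoked it casually now delete data. A caller-side guard like this sketch (helper name and flag are illustrative) preserves the old safety:

    async def drop_storage(storage, *, confirm: bool = False) -> None:
        # drop() is destructive after this commit: it removes the whole collection
        if not confirm:
            raise RuntimeError("refusing to drop storage without confirm=True")
        await storage.drop()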
@@ -4,7 +4,7 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, Union
 import numpy as np
 import array
 import pipmaster as pm
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict, None]:
+    async def get_by_id(self, id: str) -> dict[str, Any]:
         """get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -191,12 +191,9 @@ class OracleKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
-        else:
-            res = await self.db.query(SQL, params)
-        if res:
             return res
         else:
-            return None
+            return await self.db.query(SQL, params)

     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         """Specifically for llm_response_cache."""
@@ -211,7 +208,7 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return None

-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         """get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -230,29 +227,7 @@ class OracleKVStorage(BaseKVStorage):
             for row in res:
                 dict_res[row["mode"]][row["id"]] = row
             res = [{k: v} for k, v in dict_res.items()]
-        if res:
-            data = res  # [{"data":i} for i in res]
-            # print(data)
-            return data
-        else:
-            return None
-
-    async def get_by_status_and_ids(
-        self, status: str, ids: list[str]
-    ) -> Union[list[dict], None]:
-        """Specifically for llm_response_cache."""
-        if ids is not None:
-            SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
-                ids=",".join([f"'{id}'" for id in ids])
-            )
-        else:
-            SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
-        params = {"workspace": self.db.workspace, "status": status}
-        res = await self.db.query(SQL, params, multirows=True)
-        if res:
-            return res
-        else:
-            return None
+        return res

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
@@ -270,7 +245,7 @@ class OracleKVStorage(BaseKVStorage):
         return set(keys)

     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
@@ -328,14 +303,6 @@ class OracleKVStorage(BaseKVStorage):
             }

             await self.db.execute(upsert_sql, _data)
-        return None
-
-    async def change_status(self, id: str, status: str):
-        SQL = SQL_TEMPLATES["change_status"].format(
-            table_name=namespace_to_table_name(self.namespace)
-        )
-        params = {"workspace": self.db.workspace, "id": id, "status": status}
-        await self.db.execute(SQL, params)

     async def index_done_callback(self):
         if is_namespace(
@@ -745,7 +712,6 @@ SQL_TEMPLATES = {
     "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
     "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
-    "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
     "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
                     USING DUAL
                     ON (a.id = :id and a.workspace = :workspace)
@@ -30,7 +30,6 @@ from ..base import (
     DocStatus,
     DocProcessingStatus,
     BaseGraphStorage,
-    T,
 )
 from ..namespace import NameSpace, is_namespace

@@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict, None]:
+    async def get_by_id(self, id: str) -> dict[str, Any]:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -193,12 +192,9 @@ class PGKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
-        else:
-            res = await self.db.query(sql, params)
-        if res:
             return res
         else:
-            return None
+            return await self.db.query(sql, params)

     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         """Specifically for llm_response_cache."""
@@ -214,7 +210,7 @@ class PGKVStorage(BaseKVStorage):
             return None

     # Query by id
-    async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         """Get doc_chunks data by id"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -231,23 +227,15 @@ class PGKVStorage(BaseKVStorage):
             dict_res[mode] = {}
             for row in array_res:
                 dict_res[row["mode"]][row["id"]] = row
-            res = [{k: v} for k, v in dict_res.items()]
+            return [{k: v} for k, v in dict_res.items()]
         else:
-            res = await self.db.query(sql, params, multirows=True)
-        if res:
-            return res
-        else:
-            return None
+            return await self.db.query(sql, params, multirows=True)

-    async def all_keys(self) -> list[dict]:
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            sql = "select workspace,mode,id from lightrag_llm_cache"
-            res = await self.db.query(sql, multirows=True)
-            return res
-        else:
-            logger.error(
-                f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
-            )
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
+        """Specifically for llm_response_cache."""
+        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        params = {"workspace": self.db.workspace, "status": status}
+        return await self.db.query(SQL, params, multirows=True)

     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
@@ -270,7 +258,7 @@ class PGKVStorage(BaseKVStorage):
             print(params)

     ################ INSERT METHODS ################
-    async def upsert(self, data: Dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -447,14 +435,15 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed

-    async def get_by_id(self, id: str) -> Union[T, None]:
+    async def get_by_id(self, id: str) -> dict[str, Any]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.db.workspace, "id": id}
         result = await self.db.query(sql, params, True)
         if result is None or result == []:
-            return None
+            return {}
         else:
             return DocProcessingStatus(
+                content=result[0]["content"],
                 content_length=result[0]["content_length"],
                 content_summary=result[0]["content_summary"],
                 status=result[0]["status"],
@@ -483,10 +472,9 @@ class PGDocStatusStorage(DocStatusStorage):
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
         params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
-        # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
-        # Converting to be a dict
         return {
             element["id"]: DocProcessingStatus(
+                content=result[0]["content"],
                 content_summary=element["content_summary"],
                 content_length=element["content_length"],
                 status=element["status"],
@@ -518,6 +506,7 @@ class PGDocStatusStorage(DocStatusStorage):
         sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status)
                  values($1,$2,$3,$4,$5,$6)
                   on conflict(id,workspace) do update set
+                  content = EXCLUDED.content,
                   content_summary = EXCLUDED.content_summary,
                   content_length = EXCLUDED.content_length,
                   chunks_count = EXCLUDED.chunks_count,
@@ -530,6 +519,7 @@ class PGDocStatusStorage(DocStatusStorage):
                 {
                     "workspace": self.db.workspace,
                     "id": k,
+                    "content": v["content"],
                     "content_summary": v["content_summary"],
                     "content_length": v["content_length"],
                     "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
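Note: PGDocStatusStorage now round-trips the original document content (content = EXCLUDED.content on conflict), and a missing id yields {} instead of None. A reading sketch, assuming doc_status is an initialized PGDocStatusStorage:

    async def show_status(doc_status, doc_id: str) -> None:
        st = await doc_status.get_by_id(doc_id)
        # existing rows come back as DocProcessingStatus, missing ids as {}
        if st:
            print(st.status, st.content_length, st.content_summary)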
@@ -1,4 +1,5 @@
 import os
+from typing import Any
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -20,29 +21,15 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")

-    async def all_keys(self) -> list[str]:
-        keys = await self._redis.keys(f"{self.namespace}:*")
-        return [key.split(":", 1)[-1] for key in keys]
-
     async def get_by_id(self, id):
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None

-    async def get_by_ids(self, ids, fields=None):
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         pipe = self._redis.pipeline()
         for id in ids:
             pipe.get(f"{self.namespace}:{id}")
         results = await pipe.execute()

-        if fields:
-            # Filter fields if specified
-            return [
-                {field: value.get(field) for field in fields if field in value}
-                if (value := json.loads(result))
-                else None
-                for result in results
-            ]
-
         return [json.loads(result) if result else None for result in results]

     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -54,7 +41,7 @@ class RedisKVStorage(BaseKVStorage):
         existing_ids = {data[i] for i, exists in enumerate(results) if exists}
         return set(data) - existing_ids

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         pipe = self._redis.pipeline()
         for k, v in tqdm_async(data.items(), desc="Upserting"):
             pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -62,9 +49,8 @@ class RedisKVStorage(BaseKVStorage):

         for k in data:
             data[k]["_id"] = k
-        return data

-    async def drop(self):
+    async def drop(self) -> None:
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, Union

 import numpy as np
 import pipmaster as pm
@@ -108,33 +108,20 @@ class TiDBKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict, None]:
-        """根据 id 获取 doc_full 数据."""
+    async def get_by_id(self, id: str) -> dict[str, Any]:
+        """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
-        # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL, params)
-        if res:
-            data = res  # {"data":res}
-            # print (data)
-            return data
-        else:
-            return None
+        return await self.db.query(SQL, params)

     # Query by id
-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
-        """根据 id 获取 doc_chunks 数据"""
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Fetch doc_chunks data by id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
-        # print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL, multirows=True)
-        if res:
-            data = res  # [{"data":i} for i in res]
-            # print(data)
-            return data
-        else:
-            return None
+        return await self.db.query(SQL, multirows=True)

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """过滤掉重复内容"""
@@ -158,7 +145,7 @@ class TiDBKVStorage(BaseKVStorage):
         return data

     ################ INSERT full_doc AND chunks ################
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -335,6 +322,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             merge_sql = SQL_TEMPLATES["insert_relationship"]
             await self.db.execute(merge_sql, data)

+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
+        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        params = {"workspace": self.db.workspace, "status": status}
+        return await self.db.query(SQL, params, multirows=True)
+

 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
@@ -4,17 +4,15 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast, Dict
-
+from typing import Any, Callable, Coroutine, Optional, Type, Union, cast
 from .operate import (
     chunking_by_token_size,
     extract_entities,
     # local_query,global_query,hybrid_query,
-    kg_query,
-    naive_query,
-    mix_kg_vector_query,
     extract_keywords_only,
+    kg_query,
     kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
 )
@@ -24,15 +22,16 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
-    StorageNameSpace,
-    QueryParam,
+    DocProcessingStatus,
     DocStatus,
     DocStatusStorage,
+    QueryParam,
+    StorageNameSpace,
 )

 from .namespace import NameSpace, make_namespace
@@ -176,15 +175,26 @@ class LightRAG:
     enable_llm_cache_for_entity_extract: bool = True

     # extension
-    addon_params: dict = field(default_factory=dict)
-    convert_response_to_json_func: callable = convert_response_to_json
+    addon_params: dict[str, Any] = field(default_factory=dict)
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
+        convert_response_to_json
+    )

     # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")

     # Custom Chunking Function
-    chunking_func: callable = chunking_by_token_size
-    chunking_func_kwargs: dict = field(default_factory=dict)
+    chunking_func: Callable[
+        [
+            str,
+            Optional[str],
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = chunking_by_token_size

     def __post_init__(self):
         os.makedirs(self.log_dir, exist_ok=True)
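Note: chunking_func_kwargs is removed, and chunking_func is now a typed Callable that the pipeline invokes with six positional arguments: (content, split_by_character, split_by_character_only, overlap_token_size, max_token_size, tiktoken_model). A custom chunker must match that shape. A purely illustrative sketch that splits on blank lines and ignores the token budgets; the result keys mirror what chunking_by_token_size produces and should be treated as an assumption:

    from typing import Any, Optional


    def paragraph_chunker(
        content: str,
        split_by_character: Optional[str],
        split_by_character_only: bool,
        overlap_token_size: int,
        max_token_size: int,
        tiktoken_model: str,
    ) -> list[dict[str, Any]]:
        # One chunk per blank-line-separated paragraph
        return [
            {"tokens": 0, "content": part.strip(), "chunk_order_index": i}
            for i, part in enumerate(p for p in content.split("\n\n") if p.strip())
        ]

    # rag = LightRAG(..., chunking_func=paragraph_chunker)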
@@ -245,19 +255,19 @@ class LightRAG:
         ####
         # add embedding func by walter
         ####
-        self.full_docs = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
@@ -281,7 +291,7 @@ class LightRAG:
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
@@ -310,7 +320,7 @@ class LightRAG:

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
-        self.doc_status = self.doc_status_storage_cls(
+        self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
             namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
             global_config=global_config,
             embedding_func=None,
@@ -351,17 +361,12 @@ class LightRAG:
             storage.db = db_client

     def insert(
-        self, string_or_strings, split_by_character=None, split_by_character_only=False
+        self,
+        string_or_strings: Union[str, list[str]],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
     ):
-        loop = always_get_an_event_loop()
-        return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
-        )
-
-    async def ainsert(
-        self, string_or_strings, split_by_character=None, split_by_character_only=False
-    ):
-        """Insert documents with checkpoint support
+        """Sync Insert documents with checkpoint support

         Args:
             string_or_strings: Single document string or list of document strings
@@ -370,154 +375,30 @@ class LightRAG:
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+        )

-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+    async def ainsert(
+        self,
+        string_or_strings: Union[str, list[str]],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ):
+        """Async Insert documents with checkpoint support

-        # 2. Generate document IDs and initial status
-        new_docs = {
-            compute_mdhash_id(content, prefix="doc-"): {
-                "content": content,
-                "content_summary": self._get_content_summary(content),
-                "content_length": len(content),
-                "status": DocStatus.PENDING,
-                "created_at": datetime.now().isoformat(),
-                "updated_at": datetime.now().isoformat(),
-            }
-            for content in unique_contents
-        }
-
-        # 3. Filter out already processed documents
-        # _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
-        _add_doc_keys = set()
-        for doc_id in new_docs.keys():
-            current_doc = await self.doc_status.get_by_id(doc_id)
-
-            if current_doc is None:
-                _add_doc_keys.add(doc_id)
-                continue  # skip to the next doc_id
-
-            status = None
-            if isinstance(current_doc, dict):
-                status = current_doc["status"]
-            else:
-                status = current_doc.status
-
-            if status == DocStatus.FAILED:
-                _add_doc_keys.add(doc_id)
-
-        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-
-        if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
-            return
-
-        logger.info(f"Processing {len(new_docs)} new unique documents")
-
-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
-        for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
-
-            for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
-            ):
-                try:
-                    # Update status to processing
-                    doc_status = {
-                        "content_summary": doc["content_summary"],
-                        "content_length": doc["content_length"],
-                        "status": DocStatus.PROCESSING,
-                        "created_at": doc["created_at"],
-                        "updated_at": datetime.now().isoformat(),
-                    }
-                    await self.doc_status.upsert({doc_id: doc_status})
-
-                    # Generate chunks from document
-                    chunks = {
-                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                            **dp,
-                            "full_doc_id": doc_id,
-                        }
-                        for dp in self.chunking_func(
-                            doc["content"],
-                            split_by_character=split_by_character,
-                            split_by_character_only=split_by_character_only,
-                            overlap_token_size=self.chunk_overlap_token_size,
-                            max_token_size=self.chunk_token_size,
-                            tiktoken_model=self.tiktoken_model_name,
-                            **self.chunking_func_kwargs,
-                        )
-                    }
-
-                    # Update status with chunks information
-                    doc_status.update(
-                        {
-                            "chunks_count": len(chunks),
-                            "updated_at": datetime.now().isoformat(),
-                        }
-                    )
-                    await self.doc_status.upsert({doc_id: doc_status})
-
-                    try:
-                        # Store chunks in vector database
-                        await self.chunks_vdb.upsert(chunks)
-
-                        # Extract and store entities and relationships
-                        maybe_new_kg = await extract_entities(
-                            chunks,
-                            knowledge_graph_inst=self.chunk_entity_relation_graph,
-                            entity_vdb=self.entities_vdb,
-                            relationships_vdb=self.relationships_vdb,
-                            llm_response_cache=self.llm_response_cache,
-                            global_config=asdict(self),
-                        )
-
-                        if maybe_new_kg is None:
-                            raise Exception(
-                                "Failed to extract entities and relationships"
-                            )
-
-                        self.chunk_entity_relation_graph = maybe_new_kg
-
-                        # Store original document and chunks
-                        await self.full_docs.upsert(
-                            {doc_id: {"content": doc["content"]}}
-                        )
-                        await self.text_chunks.upsert(chunks)
-
-                        # Update status to processed
-                        doc_status.update(
-                            {
-                                "status": DocStatus.PROCESSED,
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        )
-                        await self.doc_status.upsert({doc_id: doc_status})
-
-                    except Exception as e:
-                        # Mark as failed if any step fails
-                        doc_status.update(
-                            {
-                                "status": DocStatus.FAILED,
-                                "error": str(e),
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        )
-                        await self.doc_status.upsert({doc_id: doc_status})
-                        raise e
-
-                except Exception as e:
-                    import traceback
-
-                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
-                    logger.error(error_msg)
-                    continue
-                else:
-                    # Only update index when processing succeeds
-                    await self._insert_done()
+        Args:
+            string_or_strings: Single document string or list of document strings
+            split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
+            chunk_size, split the sub chunk by token size.
+            split_by_character_only: if split_by_character_only is True, split the string by character only, when
+            split_by_character is None, this parameter is ignored.
+        """
+        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_process_enqueue_documents(
+            split_by_character, split_by_character_only
+        )

     def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
         loop = always_get_an_event_loop()
@@ -586,10 +467,14 @@ class LightRAG:
         if update_storage:
             await self._insert_done()

-    async def apipeline_process_documents(self, string_or_strings):
-        """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
-        Args:
-            string_or_strings: Single document string or list of document strings
+    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+        """
+        Pipeline for Processing Documents
+
+        1. Remove duplicate contents from the list
+        2. Generate document IDs and initial status
+        3. Filter out already processed documents
+        4. Enqueue document in status
         """
         if isinstance(string_or_strings, str):
             string_or_strings = [string_or_strings]
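Note: ainsert() is now a thin wrapper over two pipeline stages, and the stages can be driven separately, e.g. enqueue from a request handler and process in a background worker. A sketch, assuming rag is a constructed LightRAG instance:

    async def ingest(rag, docs: list[str]) -> None:
        # Stage 1: de-duplicate, assign doc ids, enqueue as PENDING in doc_status
        await rag.apipeline_enqueue_documents(docs)
        # Stage 2: chunk, embed, extract entities/relations, mark PROCESSED/FAILED
        await rag.apipeline_process_enqueue_documents()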
@@ -597,183 +482,187 @@ class LightRAG:
         # 1. Remove duplicate contents from the list
         unique_contents = list(set(doc.strip() for doc in string_or_strings))

+        logger.info(
+            f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
+        )
+
         # 2. Generate document IDs and initial status
-        new_docs = {
+        new_docs: dict[str, Any] = {
             compute_mdhash_id(content, prefix="doc-"): {
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
                 "status": DocStatus.PENDING,
                 "created_at": datetime.now().isoformat(),
-                "updated_at": None,
+                "updated_at": datetime.now().isoformat(),
             }
             for content in unique_contents
         }

         # 3. Filter out already processed documents
-        _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-        if len(_not_stored_doc_keys) < len(new_docs):
-            logger.info(
-                f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents"
-            )
-        new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
+        add_doc_keys: set[str] = set()
+        # Get docs ids
+        in_process_keys = list(new_docs.keys())
+        # Get in progress docs ids
+        excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
+        # Exclude already in process
+        add_doc_keys = new_docs.keys() - excluded_ids
+        # Filter
+        new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}

         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
-            return None
+            return

-        # 4. Store original document
-        for doc_id, doc in new_docs.items():
-            await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
-            await self.full_docs.change_status(doc_id, DocStatus.PENDING)
+        # 4. Store status document
+        await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")

-    async def apipeline_process_chunks(self):
-        """Get pendding documents, split into chunks,insert chunks"""
-        # 1. get all pending and failed documents
-        _todo_doc_keys = []
-        _failed_doc = await self.full_docs.get_by_status_and_ids(
-            status=DocStatus.FAILED, ids=None
-        )
-        _pendding_doc = await self.full_docs.get_by_status_and_ids(
-            status=DocStatus.PENDING, ids=None
-        )
-        if _failed_doc:
-            _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
-        if _pendding_doc:
-            _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
-        if not _todo_doc_keys:
-            logger.info("All documents have been processed or are duplicates")
-            return None
-        else:
-            logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
-
-        new_docs = {
-            doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
-        }
+    async def apipeline_process_enqueue_documents(
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
+        """
+        Process pending documents by splitting them into chunks, processing
+        each chunk for entity and relation extraction, and updating the
+        document status.
+
+        1. Get all pending and failed documents
+        2. Split document content into chunks
+        3. Process each chunk for entity and relation extraction
+        4. Update the document status
+        """
+        # 1. get all pending and failed documents
+        to_process_docs: dict[str, DocProcessingStatus] = {}
+
+        # Fetch failed documents
+        failed_docs = await self.doc_status.get_failed_docs()
+        to_process_docs.update(failed_docs)
+
+        pending_docs = await self.doc_status.get_pending_docs()
+        to_process_docs.update(pending_docs)
+
+        if not to_process_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        to_process_docs_ids = list(to_process_docs.keys())
+
+        # Get allready processed documents (text chunks and full docs)
+        text_chunks_processed_doc_ids = await self.text_chunks.filter_keys(
+            to_process_docs_ids
+        )
+        full_docs_processed_doc_ids = await self.full_docs.filter_keys(
+            to_process_docs_ids
+        )

         # 2. split docs into chunks, insert chunks, update doc status
-        chunk_cnt = 0
         batch_size = self.addon_params.get("insert_batch_size", 10)
-        for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
-            for doc_id, doc in tqdm_async(
-                batch_docs.items(),
-                desc=f"Level 1 - Spliting doc in batch {i // batch_size + 1}",
+        batch_docs_list = [
+            list(to_process_docs.items())[i : i + batch_size]
+            for i in range(0, len(to_process_docs), batch_size)
+        ]
+
+        # 3. iterate over batches
+        tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
+        for batch_idx, ids_doc_processing_status in tqdm_async(
+            enumerate(batch_docs_list),
+            desc="Process Batches",
+        ):
+            # 4. iterate over batch
+            for id_doc_processing_status in tqdm_async(
+                ids_doc_processing_status,
+                desc=f"Process Batch {batch_idx}",
             ):
-                try:
-                    # Generate chunks from document
-                    chunks = {
-                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                            **dp,
-                            "full_doc_id": doc_id,
-                            "status": DocStatus.PENDING,
-                        }
-                        for dp in chunking_by_token_size(
-                            doc["content"],
-                            overlap_token_size=self.chunk_overlap_token_size,
-                            max_token_size=self.chunk_token_size,
-                            tiktoken_model=self.tiktoken_model_name,
-                        )
-                    }
-                    chunk_cnt += len(chunks)
-                    await self.text_chunks.upsert(chunks)
-                    await self.text_chunks.change_status(doc_id, DocStatus.PROCESSING)
-
-                    try:
-                        # Store chunks in vector database
-                        await self.chunks_vdb.upsert(chunks)
-                        # Update doc status
-                        await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
-                    except Exception as e:
-                        # Mark as failed if any step fails
-                        await self.full_docs.change_status(doc_id, DocStatus.FAILED)
-                        raise e
-                except Exception as e:
-                    import traceback
-
-                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
-                    logger.error(error_msg)
-                    continue
-        logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
-
-    async def apipeline_process_extract_graph(self):
-        """Get pendding or failed chunks, extract entities and relationships from each chunk"""
-        # 1. get all pending and failed chunks
-        _todo_chunk_keys = []
-        _failed_chunks = await self.text_chunks.get_by_status_and_ids(
-            status=DocStatus.FAILED, ids=None
-        )
-        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
-            status=DocStatus.PENDING, ids=None
-        )
-        if _failed_chunks:
-            _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
-        if _pendding_chunks:
-            _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
-        if not _todo_chunk_keys:
-            logger.info("All chunks have been processed or are duplicates")
-            return None
-
-        # Process documents in batches
-        batch_size = self.addon_params.get("insert_batch_size", 10)
-
-        semaphore = asyncio.Semaphore(
-            batch_size
-        )  # Control the number of tasks that are processed simultaneously
-
-        async def process_chunk(chunk_id):
-            async with semaphore:
-                chunks = {
-                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
-                }
-                # Extract and store entities and relationships
-                try:
-                    maybe_new_kg = await extract_entities(
-                        chunks,
-                        knowledge_graph_inst=self.chunk_entity_relation_graph,
-                        entity_vdb=self.entities_vdb,
-                        relationships_vdb=self.relationships_vdb,
-                        llm_response_cache=self.llm_response_cache,
-                        global_config=asdict(self),
-                    )
-                    if maybe_new_kg is None:
-                        logger.info("No entities or relationships extracted!")
-                    # Update status to processed
-                    await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
-                except Exception as e:
-                    logger.error("Failed to extract entities and relationships")
-                    # Mark as failed if any step fails
-                    await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
-                    raise e
-
-        with tqdm_async(
-            total=len(_todo_chunk_keys),
-            desc="\nLevel 1 - Processing chunks",
-            unit="chunk",
-            position=0,
-        ) as progress:
-            tasks = []
-            for chunk_id in _todo_chunk_keys:
-                task = asyncio.create_task(process_chunk(chunk_id))
-                tasks.append(task)
-
-            for future in asyncio.as_completed(tasks):
-                await future
-                progress.update(1)
-                progress.set_postfix(
+                id_doc, status_doc = id_doc_processing_status
+                # Update status in processing
+                await self.doc_status.upsert(
                     {
-                        "LLM call": statistic_data["llm_call"],
-                        "LLM cache": statistic_data["llm_cache"],
+                        id_doc: {
+                            "status": DocStatus.PROCESSING,
+                            "updated_at": datetime.now().isoformat(),
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
+                        }
                     }
                 )
+                # Generate chunks from document
+                chunks: dict[str, Any] = {
+                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                        **dp,
+                        "full_doc_id": id_doc_processing_status,
+                    }
+                    for dp in self.chunking_func(
+                        status_doc.content,
+                        split_by_character,
+                        split_by_character_only,
+                        self.chunk_overlap_token_size,
+                        self.chunk_token_size,
+                        self.tiktoken_model_name,
+                    )
+                }

-        # Ensure all indexes are updated after each document
-        await self._insert_done()
+                # Ensure chunk insertion and graph processing happen sequentially, not in parallel
+                await self.chunks_vdb.upsert(chunks)
+                await self._process_entity_relation_graph(chunks)
+
+                tasks[id_doc] = []
+                # Check if document already processed the doc
+                if id_doc not in full_docs_processed_doc_ids:
+                    tasks[id_doc].append(
+                        self.full_docs.upsert({id_doc: {"content": status_doc.content}})
+                    )
+
+                # Check if chunks already processed the doc
+                if id_doc not in text_chunks_processed_doc_ids:
+                    tasks[id_doc].append(self.text_chunks.upsert(chunks))
+
+        # Process document (text chunks and full docs) in parallel
+        for id_doc_processing_status, task in tasks.items():
+            try:
+                await asyncio.gather(*task)
+                await self.doc_status.upsert(
+                    {
+                        id_doc_processing_status: {
+                            "status": DocStatus.PROCESSED,
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    }
+                )
+                await self._insert_done()
+
+            except Exception as e:
+                logger.error(
+                    f"Failed to process document {id_doc_processing_status}: {str(e)}"
+                )
+                await self.doc_status.upsert(
+                    {
+                        id_doc_processing_status: {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    }
+                )
+                continue
+
+    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
+        try:
+            new_kg = await extract_entities(
+                chunk,
+                knowledge_graph_inst=self.chunk_entity_relation_graph,
+                entity_vdb=self.entities_vdb,
+                relationships_vdb=self.relationships_vdb,
+                llm_response_cache=self.llm_response_cache,
+                global_config=asdict(self),
+            )
+            if new_kg is None:
+                logger.info("No entities or relationships extracted!")
+            else:
+                self.chunk_entity_relation_graph = new_kg
+
+        except Exception as e:
+            logger.error("Failed to extract entities and relationships")
+            raise e

     async def _insert_done(self):
         tasks = []
@@ -1169,7 +1058,7 @@ class LightRAG:
             return content
         return content[:max_length] + "..."

-    async def get_processing_status(self) -> Dict[str, int]:
+    async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts

        Returns:
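Note: get_processing_status() now returns dict[str, int], so a progress report is one line per status. A sketch, assuming rag is a LightRAG instance:

    async def report(rag) -> None:
        counts: dict[str, int] = await rag.get_processing_status()
        for status, n in counts.items():
            print(f"{status}: {n}")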
@@ -2,7 +2,7 @@ import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Union
+from typing import Any, Union
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,15 +36,14 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character=None,
-    split_by_character_only=False,
-    overlap_token_size=128,
-    max_token_size=1024,
-    tiktoken_model="gpt-4o",
-    **kwargs,
-):
+    split_by_character: Union[str, None] = None,
+    split_by_character_only: bool = False,
+    overlap_token_size: int = 128,
+    max_token_size: int = 1024,
+    tiktoken_model: str = "gpt-4o",
+) -> list[dict[str, Any]]:
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
-    results = []
+    results: list[dict[str, Any]] = []
     if split_by_character:
         raw_chunks = content.split(split_by_character)
         new_chunks = []
@@ -568,7 +567,7 @@ async def kg_query(
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
@@ -777,7 +776,7 @@ async def mix_kg_vector_query(
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
@@ -969,7 +968,7 @@ async def _build_query_context(
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
     # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
@@ -1052,7 +1051,7 @@ async def _get_node_data(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
     # get similar entities
@@ -1145,7 +1144,7 @@ async def _get_node_data(
 async def _find_most_related_text_unit_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
@@ -1268,7 +1267,7 @@ async def _get_edge_data(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@@ -1421,7 +1420,7 @@ async def _find_most_related_entities_from_relationships(
 async def _find_related_text_unit_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
@@ -1496,7 +1495,7 @@ def combine_contexts(entities, relationships, sources):
 async def naive_query(
     query,
     chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
@@ -1599,7 +1598,7 @@ async def kg_query_with_keywords(
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
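Note: the **kwargs escape hatch on chunking_by_token_size is gone, so stray keyword arguments now raise TypeError. A direct call with the new signature (module path and sample text are assumptions):

    from lightrag.operate import chunking_by_token_size

    chunks = chunking_by_token_size(
        "some long document text ...",
        split_by_character=None,
        split_by_character_only=False,
        overlap_token_size=128,
        max_token_size=1024,
        tiktoken_model="gpt-4o",
    )
    print(len(chunks), chunks[0]["content"][:40])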
@@ -98,7 +98,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||
return None
|
||||
|
||||
|
||||
def convert_response_to_json(response: str) -> dict:
|
||||
def convert_response_to_json(response: str) -> dict[str, Any]:
|
||||
json_str = locate_json_string_body_from_string(response)
|
||||
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
||||
try:
|
||||
|