@@ -6,8 +6,10 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast
-from asyncio import Lock
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final
+
+from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES
+
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -32,8 +34,10 @@ from .operate import (
 from .prompt import GRAPH_FIELD_SEP
 from .utils import (
     EmbeddingFunc,
+    always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
@@ -43,210 +47,22 @@ from .utils import (
 
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
-# Storage type and implementation compatibility validation table
-STORAGE_IMPLEMENTATIONS = {
-    "KV_STORAGE": {
-        "implementations": [
-            "JsonKVStorage",
-            "MongoKVStorage",
-            "RedisKVStorage",
-            "TiDBKVStorage",
-            "PGKVStorage",
-            "OracleKVStorage",
-        ],
-        "required_methods": ["get_by_id", "upsert"],
-    },
-    "GRAPH_STORAGE": {
-        "implementations": [
-            "NetworkXStorage",
-            "Neo4JStorage",
-            "MongoGraphStorage",
-            "TiDBGraphStorage",
-            "AGEStorage",
-            "GremlinStorage",
-            "PGGraphStorage",
-            "OracleGraphStorage",
-        ],
-        "required_methods": ["upsert_node", "upsert_edge"],
-    },
-    "VECTOR_STORAGE": {
-        "implementations": [
-            "NanoVectorDBStorage",
-            "MilvusVectorDBStorage",
-            "ChromaVectorDBStorage",
-            "TiDBVectorDBStorage",
-            "PGVectorStorage",
-            "FaissVectorDBStorage",
-            "QdrantVectorDBStorage",
-            "OracleVectorDBStorage",
-            "MongoVectorDBStorage",
-        ],
-        "required_methods": ["query", "upsert"],
-    },
"DOC_STATUS_STORAGE": {
|
|
|
|
|
"implementations": [
|
|
|
|
|
"JsonDocStatusStorage",
|
|
|
|
|
"PGDocStatusStorage",
|
|
|
|
|
"PGDocStatusStorage",
|
|
|
|
|
"MongoDocStatusStorage",
|
|
|
|
|
],
|
|
|
|
|
"required_methods": ["get_docs_by_status"],
|
|
|
|
|
},
|
|
|
|
|
}
-
-# Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
-    # KV Storage Implementations
-    "JsonKVStorage": [],
-    "MongoKVStorage": [],
-    "RedisKVStorage": ["REDIS_URI"],
-    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "OracleKVStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Graph Storage Implementations
-    "NetworkXStorage": [],
-    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
-    "MongoGraphStorage": [],
-    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "AGEStorage": [
-        "AGE_POSTGRES_DB",
-        "AGE_POSTGRES_USER",
-        "AGE_POSTGRES_PASSWORD",
-    ],
-    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
-    "PGGraphStorage": [
-        "POSTGRES_USER",
-        "POSTGRES_PASSWORD",
-        "POSTGRES_DATABASE",
-    ],
-    "OracleGraphStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Vector Storage Implementations
-    "NanoVectorDBStorage": [],
-    "MilvusVectorDBStorage": [],
-    "ChromaVectorDBStorage": [],
-    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "FaissVectorDBStorage": [],
-    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
-    "OracleVectorDBStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    "MongoVectorDBStorage": [],
-    # Document Status Storage Implementations
-    "JsonDocStatusStorage": [],
-    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "MongoDocStatusStorage": [],
-}
-
-# Storage implementation module mapping
-STORAGES = {
-    "NetworkXStorage": ".kg.networkx_impl",
-    "JsonKVStorage": ".kg.json_kv_impl",
-    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
-    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
-    "Neo4JStorage": ".kg.neo4j_impl",
-    "OracleKVStorage": ".kg.oracle_impl",
-    "OracleGraphStorage": ".kg.oracle_impl",
-    "OracleVectorDBStorage": ".kg.oracle_impl",
-    "MilvusVectorDBStorage": ".kg.milvus_impl",
-    "MongoKVStorage": ".kg.mongo_impl",
-    "MongoDocStatusStorage": ".kg.mongo_impl",
-    "MongoGraphStorage": ".kg.mongo_impl",
-    "MongoVectorDBStorage": ".kg.mongo_impl",
-    "RedisKVStorage": ".kg.redis_impl",
-    "ChromaVectorDBStorage": ".kg.chroma_impl",
-    "TiDBKVStorage": ".kg.tidb_impl",
-    "TiDBVectorDBStorage": ".kg.tidb_impl",
-    "TiDBGraphStorage": ".kg.tidb_impl",
-    "PGKVStorage": ".kg.postgres_impl",
-    "PGVectorStorage": ".kg.postgres_impl",
-    "AGEStorage": ".kg.age_impl",
-    "PGGraphStorage": ".kg.postgres_impl",
-    "GremlinStorage": ".kg.gremlin_impl",
-    "PGDocStatusStorage": ".kg.postgres_impl",
-    "FaissVectorDBStorage": ".kg.faiss_impl",
-    "QdrantVectorDBStorage": ".kg.qdrant_impl",
-}
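
For orientation: each entry above pairs a class name with the module that implements it, and _get_storage_class resolves a configured name through lazy_external_import, so a backend's driver is only imported once that backend is actually selected. A minimal sketch of the same resolution, assuming the mapping above and the lightrag package name; resolve_storage is a hypothetical stand-in for that logic:

    import importlib

    def resolve_storage(storage_name: str):
        # Map the class name to its module, e.g. "Neo4JStorage" -> ".kg.neo4j_impl",
        # and import that module relative to the lightrag package only at this point.
        module_path = STORAGES[storage_name]
        module = importlib.import_module(module_path, package="lightrag")
        return getattr(module, storage_name)

    Neo4JStorage = resolve_storage("Neo4JStorage")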
-
-
-def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
-    """Lazily import a class from an external module based on the package of the caller."""
-    # Get the caller's module and package
-    import inspect
-
-    caller_frame = inspect.currentframe().f_back
-    module = inspect.getmodule(caller_frame)
-    package = module.__package__ if module else None
-
-    def import_class(*args: Any, **kwargs: Any):
-        import importlib
-
-        module = importlib.import_module(module_name, package=package)
-        cls = getattr(module, class_name)
-        return cls(*args, **kwargs)
-
-    return import_class
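
The callable returned above stands in for the class itself: importing the implementation module is deferred until the first instantiation. A minimal usage sketch; the constructor arguments are illustrative only:

    # Nothing from .kg.nano_vector_db_impl is imported yet...
    NanoVectorDBStorage = lazy_external_import(
        ".kg.nano_vector_db_impl", "NanoVectorDBStorage"
    )
    # ...the import happens here, on the first call:
    storage = NanoVectorDBStorage(namespace="chunks", global_config={})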
-
-
-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    """
-    Ensure that there is always an event loop available.
-
-    This function tries to get the current event loop. If the current event loop is closed or does not exist,
-    it creates a new event loop and sets it as the current event loop.
-
-    Returns:
-        asyncio.AbstractEventLoop: The current or newly created event loop.
-    """
-    try:
-        # Try to get the current event loop
-        current_loop = asyncio.get_event_loop()
-        if current_loop.is_closed():
-            raise RuntimeError("Event loop is closed.")
-        return current_loop
-
-    except RuntimeError:
-        # If no event loop exists or it is closed, create a new one
-        logger.info("Creating a new event loop in main thread.")
-        new_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(new_loop)
-        return new_loop
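
This helper is what lets the synchronous entry points drive the async API from plain scripts; a minimal sketch of the pattern, with do_work standing in for any coroutine:

    import asyncio

    async def do_work() -> str:
        return "done"

    loop = always_get_an_event_loop()
    result = loop.run_until_complete(do_work())  # "done", without asyncio.run()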
-
-
+@final
 @dataclass
 class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
 
+    # Directory
+    # ---
+
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""
 
-    embedding_cache_config: dict[str, Any] = field(
-        default_factory=lambda: {
-            "enabled": False,
-            "similarity_threshold": 0.95,
-            "use_llm_check": False,
-        }
-    )
-    """Configuration for embedding cache.
-    - enabled: If True, enables caching to avoid redundant computations.
-    - similarity_threshold: Minimum similarity score to use cached embeddings.
-    - use_llm_check: If True, validates cached embeddings using an LLM.
-    """
+    # Storage
+    # ---
+
     kv_storage: str = field(default="JsonKVStorage")
     """Storage backend for key-value data."""
@@ -261,32 +77,74 @@ class LightRAG:
     """Storage type for tracking document processing statuses."""
 
     # Logging
-    current_log_level = logger.level
-    log_level: int = field(default=current_log_level)
+    # ---
+
+    log_level: int = field(default=logger.level)
     """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
 
-    log_dir: str = field(default=os.getcwd())
-    """Directory where logs are stored. Defaults to the current working directory."""
-
-    # Text chunking
-    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
-    """Maximum number of tokens per text chunk when splitting documents."""
-
-    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
-    """Number of overlapping tokens between consecutive text chunks to preserve context."""
-
-    tiktoken_model_name: str = "gpt-4o-mini"
-    """Model name used for tokenization when chunking text."""
+    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
+    """Log file path."""
 
     # Entity extraction
-    entity_extract_max_gleaning: int = 1
+    # ---
+
+    entity_extract_max_gleaning: int = field(default=1)
     """Maximum number of entity extraction attempts for ambiguous content."""
 
-    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
+    entity_summary_to_max_tokens: int = field(
+        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    )
     """Maximum number of tokens used for summarizing extracted entities."""
+
+    # Text chunking
+    # ---
+
+    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
+    """Maximum number of tokens per text chunk when splitting documents."""
+
+    chunk_overlap_token_size: int = field(
+        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
+    )
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
+    tiktoken_model_name: str = field(default="gpt-4o-mini")
+    """Model name used for tokenization when chunking text."""
+
+    chunking_func: Callable[
+        [
+            str,
+            str | None,
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = field(default_factory=lambda: chunking_by_token_size)
+    """
+    Custom chunking function for splitting text into chunks before processing.
+
+    The function should take the following parameters:
+
+    - `content`: The text to be split into chunks.
+    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
+    - `split_by_character_only`: If True, the text is split only on the specified character.
+    - `chunk_token_size`: The maximum number of tokens per chunk.
+    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
+    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
+
+    The function should return a list of dictionaries, where each dictionary contains the following keys:
+    - `tokens`: The number of tokens in the chunk.
+    - `content`: The text content of the chunk.
+
+    Defaults to `chunking_by_token_size` if not specified.
+    """
 
     # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+
+    node_embedding_algorithm: str = field(default="node2vec")
     """Algorithm used for node embedding in knowledge graphs."""
 
     node2vec_params: dict[str, int] = field(
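
The chunking contract documented above is easy to satisfy with a custom callable. A minimal sketch of a paragraph-level chunker matching that signature; counting tokens by whitespace splitting is a simplifying assumption, not what chunking_by_token_size does:

    from typing import Any

    def chunk_by_paragraph(
        content: str,
        split_by_character: str | None,
        split_by_character_only: bool,
        chunk_token_size: int,
        chunk_overlap_token_size: int,
        tiktoken_model_name: str,
    ) -> list[dict[str, Any]]:
        # Split on blank lines; ignore the overlap and model-name knobs.
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
        return [{"tokens": len(p.split()), "content": p} for p in paragraphs]

Wiring it in is then just LightRAG(..., chunking_func=chunk_by_paragraph).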
@@ -308,119 +166,98 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """
 
-    embedding_func: EmbeddingFunc | None = None
+    # Embedding
+    # ---
+
+    embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""
 
-    embedding_batch_num: int = 32
+    embedding_batch_num: int = field(default=32)
     """Batch size for embedding computations."""
 
-    embedding_func_max_async: int = 16
+    embedding_func_max_async: int = field(default=16)
     """Maximum number of concurrent embedding function calls."""
 
+    embedding_cache_config: dict[str, Any] = field(
+        default_factory=lambda: {
+            "enabled": False,
+            "similarity_threshold": 0.95,
+            "use_llm_check": False,
+        }
+    )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """
+
     # LLM Configuration
-    llm_model_func: Callable[..., object] | None = None
+    # ---
+
+    llm_model_func: Callable[..., object] | None = field(default=None)
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    llm_model_name: str = field(default="gpt-4o-mini")
     """Name of the LLM model used for generating responses."""
 
-    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
+    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
     """Maximum number of tokens allowed per LLM response."""
 
-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
    """Maximum number of concurrent LLM calls."""
 
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
 
     # Storage
+    # ---
+
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""
 
     namespace_prefix: str = field(default="")
     """Prefix for namespacing stored data across different environments."""
 
-    enable_llm_cache: bool = True
+    enable_llm_cache: bool = field(default=True)
     """Enables caching for LLM responses to avoid redundant computations."""
 
-    enable_llm_cache_for_entity_extract: bool = True
+    enable_llm_cache_for_entity_extract: bool = field(default=True)
     """If True, enables caching for entity extraction steps to reduce LLM costs."""
 
     # Extensions
+    # ---
+
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
+    """Maximum number of parallel insert operations."""
+
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""
 
-    # Storages Management
-    auto_manage_storages_states: bool = True
+    # Storages Management
+    # ---
+
+    auto_manage_storages_states: bool = field(default=True)
     """If True, LightRAG will automatically call initialize_storages and finalize_storages at the appropriate times."""
 
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
-        convert_response_to_json
-    )
-
-    # Lock for entity extraction
-    _entity_lock = Lock()
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
+        default_factory=lambda: convert_response_to_json
+    )
+    """
+    Custom function for converting LLM responses to JSON format.
+
+    The default function is :func:`.utils.convert_response_to_json`.
+    """
 
-    # Custom Chunking Function
-    chunking_func: Callable[
-        [
-            str,
-            str | None,
-            bool,
-            int,
-            int,
-            str,
-        ],
-        list[dict[str, Any]],
-    ] = chunking_by_token_size
-
-    def verify_storage_implementation(
-        self, storage_type: str, storage_name: str
-    ) -> None:
-        """Verify if storage implementation is compatible with specified storage type
-
-        Args:
-            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If storage implementation is incompatible or missing required methods
-        """
-        if storage_type not in STORAGE_IMPLEMENTATIONS:
-            raise ValueError(f"Unknown storage type: {storage_type}")
-
-        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
-        if storage_name not in storage_info["implementations"]:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
-                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
-            )
-
-    def check_storage_env_vars(self, storage_name: str) -> None:
-        """Check if all required environment variables for storage implementation exist
-
-        Args:
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If required environment variables are missing
-        """
-        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
-        missing_vars = [var for var in required_vars if var not in os.environ]
-
-        if missing_vars:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' requires the following "
-                f"environment variables: {', '.join(missing_vars)}"
-            )
+    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
 
     def __post_init__(self):
-        os.makedirs(self.log_dir, exist_ok=True)
-        log_file = os.path.join(self.log_dir, "lightrag.log")
-        set_logger(log_file)
+        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
+        set_logger(self.log_file_path)
         logger.setLevel(self.log_level)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
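
One detail worth noting in the block above: a mutable default such as the embedding_cache_config dict has to go through field(default_factory=...), because dataclasses reject mutable field(default=...) values when the class is defined. A standalone sketch of the rule:

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class Config:
        # OK: every instance gets its own fresh dict
        cache: dict[str, Any] = field(default_factory=lambda: {"enabled": False})
        # By contrast, cache: dict = field(default={"enabled": False})
        # raises ValueError at class-definition time.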
@@ -448,9 +285,6 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }
 
-        # Life cycle
-        self.storages_status = StoragesStatus.NOT_CREATED
-
         # Show config
         global_config = asdict(self)
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -558,7 +392,7 @@ class LightRAG:
             )
         )
 
-        self.storages_status = StoragesStatus.CREATED
+        self._storages_status = StoragesStatus.CREATED
 
         # Initialize storages
         if self.auto_manage_storages_states:
@@ -573,7 +407,7 @@ class LightRAG:
 
     async def initialize_storages(self):
         """Asynchronously initialize the storages"""
-        if self.storages_status == StoragesStatus.CREATED:
+        if self._storages_status == StoragesStatus.CREATED:
             tasks = []
 
             for storage in (
@@ -591,12 +425,12 @@ class LightRAG:
 
             await asyncio.gather(*tasks)
 
-            self.storages_status = StoragesStatus.INITIALIZED
+            self._storages_status = StoragesStatus.INITIALIZED
             logger.debug("Initialized Storages")
 
     async def finalize_storages(self):
         """Asynchronously finalize the storages"""
-        if self.storages_status == StoragesStatus.INITIALIZED:
+        if self._storages_status == StoragesStatus.INITIALIZED:
             tasks = []
 
             for storage in (
@@ -614,7 +448,7 @@ class LightRAG:
 
             await asyncio.gather(*tasks)
 
-            self.storages_status = StoragesStatus.FINALIZED
+            self._storages_status = StoragesStatus.FINALIZED
             logger.debug("Finalized Storages")
 
     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
@@ -789,10 +623,9 @@ class LightRAG:
             return
 
         # 2. split docs into chunks, insert chunks, update doc status
-        batch_size = self.addon_params.get("insert_batch_size", 10)
         docs_batches = [
-            list(to_process_docs.items())[i : i + batch_size]
-            for i in range(0, len(to_process_docs), batch_size)
+            list(to_process_docs.items())[i : i + self.max_parallel_insert]
+            for i in range(0, len(to_process_docs), self.max_parallel_insert)
         ]
 
         logger.info(f"Number of batches to process: {len(docs_batches)}.")
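
The batch size now comes from max_parallel_insert instead of addon_params. The slicing itself is ordinary list partitioning; a small standalone check of the arithmetic:

    to_process_docs = {f"doc-{i}": None for i in range(45)}
    max_parallel_insert = 20

    docs_batches = [
        list(to_process_docs.items())[i : i + max_parallel_insert]
        for i in range(0, len(to_process_docs), max_parallel_insert)
    ]
    # 45 documents with a batch size of 20 -> batches of 20, 20 and 5
    assert [len(b) for b in docs_batches] == [20, 20, 5]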
@@ -1203,7 +1036,6 @@ class LightRAG:
         # ---------------------
         # STEP 1: Keyword Extraction
         # ---------------------
-        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
         hl_keywords, ll_keywords = await extract_keywords_only(
             text=query,
             param=param,
@@ -1629,3 +1461,43 @@ class LightRAG:
         result["vector_data"] = vector_data[0] if vector_data else None
 
         return result
+
+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify if storage implementation is compatible with specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check if all required environment variables for storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )
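
Together these checks let callers fail fast before any backend is constructed. A minimal usage sketch, assuming an instance configured for Neo4j; the remaining constructor arguments (embedding_func, llm_model_func, etc.) are omitted:

    rag = LightRAG(graph_storage="Neo4JStorage")

    # Raises ValueError if Neo4JStorage is not a compatible GRAPH_STORAGE backend
    rag.verify_storage_implementation("GRAPH_STORAGE", "Neo4JStorage")

    # Raises ValueError unless NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD are set
    rag.check_storage_env_vars("Neo4JStorage")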