@@ -1,5 +1,3 @@
version: '3.8'

services:
  lightrag:
    build: .

@@ -98,7 +98,6 @@ async def init():

    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,

@@ -1,9 +1,8 @@
import os
import inspect
import os
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
from lightrag import QueryParam

# WorkingDir

@@ -63,7 +63,6 @@ async def main():

    # Initialize LightRAG
    # We use TiDB as the KV/vector storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,

@@ -1 +1,136 @@
# print ("init package vars here. ......")
STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "MongoKVStorage",
            "RedisKVStorage",
            "TiDBKVStorage",
            "PGKVStorage",
            "OracleKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
    "GRAPH_STORAGE": {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "MongoGraphStorage",
            "TiDBGraphStorage",
            "AGEStorage",
            "GremlinStorage",
            "PGGraphStorage",
            "OracleGraphStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
    "VECTOR_STORAGE": {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "TiDBVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
    "DOC_STATUS_STORAGE": {
        "implementations": [
            "JsonDocStatusStorage",
            "PGDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
    },
}

# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "OracleKVStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [],
    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    "OracleGraphStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [],
    "ChromaVectorDBStorage": [],
    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "OracleVectorDBStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "MongoDocStatusStorage": [],
}

# Storage implementation module mapping
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
}

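The three tables above drive backend selection. As a rough illustration (not part of this diff), they can be combined to validate and lazily load a storage class by name; the `resolve_storage_class` helper below is hypothetical, and `importlib` stands in for the `lazy_external_import` utility shown later in this diff.

import importlib
import os

from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES


def resolve_storage_class(storage_type: str, storage_name: str):
    """Hypothetical helper mirroring the checks LightRAG performs before loading a backend."""
    # The implementation must be registered for the requested storage type.
    if storage_name not in STORAGE_IMPLEMENTATIONS[storage_type]["implementations"]:
        raise ValueError(f"{storage_name} is not a valid {storage_type} implementation")

    # Every environment variable without a default must be present.
    missing = [v for v in STORAGE_ENV_REQUIREMENTS.get(storage_name, []) if v not in os.environ]
    if missing:
        raise ValueError(f"{storage_name} requires env vars: {', '.join(missing)}")

    # STORAGES maps the class name to a module path relative to the `lightrag` package.
    module = importlib.import_module(STORAGES[storage_name], package="lightrag")
    return getattr(module, storage_name)


# Example: the default JSON key-value backend has no required environment variables.
JsonKVStorage = resolve_storage_class("KV_STORAGE", "JsonKVStorage")
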
@@ -44,7 +44,7 @@ class OracleDB:
        self.increment = 1
        logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
        if self.user is None or self.password is None:
            raise ValueError("Missing database user or password in addon_params")
            raise ValueError("Missing database user or password")

        try:
            oracledb.defaults.fetch_lobs = False

@@ -54,9 +54,7 @@ class PostgreSQLDB:
        self.pool: Pool | None = None

        if self.user is None or self.password is None or self.database is None:
            raise ValueError(
                "Missing database user, password, or database in addon_params"
            )
            raise ValueError("Missing database user, password, or database")

    async def initdb(self):
        try:

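The constructors above validate credentials up front. A minimal sketch of the environment a PostgreSQL-backed run needs, assuming the three POSTGRES_* variables listed in STORAGE_ENV_REQUIREMENTS (the values shown are placeholders only):

import os

# Placeholder values: point these at a real PostgreSQL instance.
os.environ.setdefault("POSTGRES_USER", "lightrag")
os.environ.setdefault("POSTGRES_PASSWORD", "change-me")
os.environ.setdefault("POSTGRES_DATABASE", "lightrag")

# With these variables unset, PostgreSQLDB raises the ValueError shown above, and
# LightRAG's check_storage_env_vars("PGKVStorage") would flag the backend as misconfigured.
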
@@ -6,8 +6,10 @@ import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast
from asyncio import Lock
from typing import Any, AsyncIterator, Callable, Iterator, cast, final

from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES

from .base import (
    BaseGraphStorage,
    BaseKVStorage,

@@ -32,8 +34,10 @@ from .operate import (
from .prompt import GRAPH_FIELD_SEP
from .utils import (
    EmbeddingFunc,
    always_get_an_event_loop,
    compute_mdhash_id,
    convert_response_to_json,
    lazy_external_import,
    limit_async_func_call,
    logger,
    set_logger,

@@ -43,210 +47,22 @@ from .utils import (
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "MongoKVStorage",
            "RedisKVStorage",
            "TiDBKVStorage",
            "PGKVStorage",
            "OracleKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
    "GRAPH_STORAGE": {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "MongoGraphStorage",
            "TiDBGraphStorage",
            "AGEStorage",
            "GremlinStorage",
            "PGGraphStorage",
            "OracleGraphStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
    "VECTOR_STORAGE": {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "TiDBVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
    "DOC_STATUS_STORAGE": {
        "implementations": [
            "JsonDocStatusStorage",
            "PGDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
    },
}

# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "OracleKVStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [],
    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    "OracleGraphStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [],
    "ChromaVectorDBStorage": [],
    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "OracleVectorDBStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "MongoDocStatusStorage": [],
}

# Storage implementation module mapping
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
}


def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


@final
@dataclass
class LightRAG:
    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

    # Directory
    # ---

    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    """Directory where cache and temporary files are stored."""

    embedding_cache_config: dict[str, Any] = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    """Configuration for embedding cache.
    - enabled: If True, enables caching to avoid redundant computations.
    - similarity_threshold: Minimum similarity score to use cached embeddings.
    - use_llm_check: If True, validates cached embeddings using an LLM.
    """
    # Storage
    # ---

    kv_storage: str = field(default="JsonKVStorage")
    """Storage backend for key-value data."""

@@ -261,32 +77,74 @@ class LightRAG:
    """Storage type for tracking document processing statuses."""

    # Logging
    current_log_level = logger.level
    log_level: int = field(default=current_log_level)
    # ---

    log_level: int = field(default=logger.level)
    """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""

    log_dir: str = field(default=os.getcwd())
    """Directory where logs are stored. Defaults to the current working directory."""

    # Text chunking
    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
    """Maximum number of tokens per text chunk when splitting documents."""

    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
    """Number of overlapping tokens between consecutive text chunks to preserve context."""

    tiktoken_model_name: str = "gpt-4o-mini"
    """Model name used for tokenization when chunking text."""
    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
    """Log file path."""

    # Entity extraction
    entity_extract_max_gleaning: int = 1
    # ---

    entity_extract_max_gleaning: int = field(default=1)
    """Maximum number of entity extraction attempts for ambiguous content."""

    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
    entity_summary_to_max_tokens: int = field(
        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
    )

    # Text chunking
    # ---

    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
    """Maximum number of tokens per text chunk when splitting documents."""

    chunk_overlap_token_size: int = field(
        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
    )
    """Number of overlapping tokens between consecutive text chunks to preserve context."""

    tiktoken_model_name: str = field(default="gpt-4o-mini")
    """Model name used for tokenization when chunking text."""

    """Maximum number of tokens used for summarizing extracted entities."""

    chunking_func: Callable[
        [
            str,
            str | None,
            bool,
            int,
            int,
            str,
        ],
        list[dict[str, Any]],
    ] = field(default_factory=lambda: chunking_by_token_size)
    """
    Custom chunking function for splitting text into chunks before processing.

    The function should take the following parameters:

    - `content`: The text to be split into chunks.
    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
    - `split_by_character_only`: If True, the text is split only on the specified character.
    - `chunk_token_size`: The maximum number of tokens per chunk.
    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.

    The function should return a list of dictionaries, where each dictionary contains the following keys:
    - `tokens`: The number of tokens in the chunk.
    - `content`: The text content of the chunk.

    Defaults to `chunking_by_token_size` if not specified.
    """

    # Node embedding
    node_embedding_algorithm: str = "node2vec"
    # ---

    node_embedding_algorithm: str = field(default="node2vec")
    """Algorithm used for node embedding in knowledge graphs."""

    node2vec_params: dict[str, int] = field(

@@ -308,119 +166,98 @@ class LightRAG:
        - random_seed: Seed value for reproducibility.
    """

    embedding_func: EmbeddingFunc | None = None
    # Embedding
    # ---

    embedding_func: EmbeddingFunc | None = field(default=None)
    """Function for computing text embeddings. Must be set before use."""

    embedding_batch_num: int = 32
    embedding_batch_num: int = field(default=32)
    """Batch size for embedding computations."""

    embedding_func_max_async: int = 16
    embedding_func_max_async: int = field(default=16)
    """Maximum number of concurrent embedding function calls."""

    embedding_cache_config: dict[str, Any] = field(
        default={
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    """Configuration for embedding cache.
    - enabled: If True, enables caching to avoid redundant computations.
    - similarity_threshold: Minimum similarity score to use cached embeddings.
    - use_llm_check: If True, validates cached embeddings using an LLM.
    """

    # LLM Configuration
    llm_model_func: Callable[..., object] | None = None
    # ---

    llm_model_func: Callable[..., object] | None = field(default=None)
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    llm_model_name: str = field(default="gpt-4o-mini")
    """Name of the LLM model used for generating responses."""

    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
    """Maximum number of tokens allowed per LLM response."""

    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
    """Maximum number of concurrent LLM calls."""

    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Storage
    # ---

    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    namespace_prefix: str = field(default="")
    """Prefix for namespacing stored data across different environments."""

    enable_llm_cache: bool = True
    enable_llm_cache: bool = field(default=True)
    """Enables caching for LLM responses to avoid redundant computations."""

    enable_llm_cache_for_entity_extract: bool = True
    enable_llm_cache_for_entity_extract: bool = field(default=True)
    """If True, enables caching for entity extraction steps to reduce LLM costs."""

    # Extensions
    # ---

    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
    """Maximum number of parallel insert operations."""

    addon_params: dict[str, Any] = field(default_factory=dict)

    # Storages Management
    auto_manage_storages_states: bool = True
    # ---

    auto_manage_storages_states: bool = field(default=True)
    """If True, LightRAG will automatically call initialize_storages and finalize_storages at the appropriate times."""

    """Dictionary for additional parameters and extensions."""
    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
        convert_response_to_json
    # Storages Management
    # ---

    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
        default_factory=lambda: convert_response_to_json
    )
    """
    Custom function for converting LLM responses to JSON format.

    # Lock for entity extraction
    _entity_lock = Lock()
    The default function is :func:`.utils.convert_response_to_json`.
    """

    # Custom Chunking Function
    chunking_func: Callable[
        [
            str,
            str | None,
            bool,
            int,
            int,
            str,
        ],
        list[dict[str, Any]],
    ] = chunking_by_token_size

    def verify_storage_implementation(
        self, storage_type: str, storage_name: str
    ) -> None:
        """Verify if storage implementation is compatible with specified storage type

        Args:
            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
            storage_name: Storage implementation name

        Raises:
            ValueError: If storage implementation is incompatible or missing required methods
        """
        if storage_type not in STORAGE_IMPLEMENTATIONS:
            raise ValueError(f"Unknown storage type: {storage_type}")

        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
        if storage_name not in storage_info["implementations"]:
            raise ValueError(
                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
            )

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check if all required environment variables for storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )
    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

    def __post_init__(self):
        os.makedirs(self.log_dir, exist_ok=True)
        log_file = os.path.join(self.log_dir, "lightrag.log")
        set_logger(log_file)

        logger.setLevel(self.log_level)
        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
        set_logger(self.log_file_path)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

@@ -448,9 +285,6 @@ class LightRAG:
            **self.vector_db_storage_cls_kwargs,
        }

        # Life cycle
        self.storages_status = StoragesStatus.NOT_CREATED

        # Show config
        global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])

@@ -558,7 +392,7 @@ class LightRAG:
            )
        )

        self.storages_status = StoragesStatus.CREATED
        self._storages_status = StoragesStatus.CREATED

        # Initialize storages
        if self.auto_manage_storages_states:

@@ -573,7 +407,7 @@ class LightRAG:

    async def initialize_storages(self):
        """Asynchronously initialize the storages"""
        if self.storages_status == StoragesStatus.CREATED:
        if self._storages_status == StoragesStatus.CREATED:
            tasks = []

            for storage in (

@@ -591,12 +425,12 @@ class LightRAG:

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.INITIALIZED
            self._storages_status = StoragesStatus.INITIALIZED
            logger.debug("Initialized Storages")

    async def finalize_storages(self):
        """Asynchronously finalize the storages"""
        if self.storages_status == StoragesStatus.INITIALIZED:
        if self._storages_status == StoragesStatus.INITIALIZED:
            tasks = []

            for storage in (

@@ -614,7 +448,7 @@ class LightRAG:

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.FINALIZED
            self._storages_status = StoragesStatus.FINALIZED
            logger.debug("Finalized Storages")

    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:

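The `_storages_status` transitions above (CREATED, INITIALIZED, FINALIZED) can also be driven by hand. A sketch of the manual life cycle, assuming `rag` is a LightRAG instance constructed with `auto_manage_storages_states=False` and with its `llm_model_func` and `embedding_func` already configured elsewhere:

from lightrag import LightRAG


async def run_with_manual_lifecycle(rag: LightRAG) -> None:
    # The caller drives the CREATED -> INITIALIZED -> FINALIZED transitions itself.
    await rag.initialize_storages()
    try:
        await rag.ainsert("LightRAG stores chunks, entities and relations in pluggable backends.")
    finally:
        # Flush and close every storage, mirroring finalize_storages() above.
        await rag.finalize_storages()

A synchronous caller could execute this with `asyncio.run(run_with_manual_lifecycle(rag))` or via the `always_get_an_event_loop()` helper shown later in this diff.
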
@@ -789,10 +623,9 @@ class LightRAG:
            return

        # 2. split docs into chunks, insert chunks, update doc status
        batch_size = self.addon_params.get("insert_batch_size", 10)
        docs_batches = [
            list(to_process_docs.items())[i : i + batch_size]
            for i in range(0, len(to_process_docs), batch_size)
            list(to_process_docs.items())[i : i + self.max_parallel_insert]
            for i in range(0, len(to_process_docs), self.max_parallel_insert)
        ]

        logger.info(f"Number of batches to process: {len(docs_batches)}.")

@@ -1203,7 +1036,6 @@ class LightRAG:
        # ---------------------
        # STEP 1: Keyword Extraction
        # ---------------------
        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
        hl_keywords, ll_keywords = await extract_keywords_only(
            text=query,
            param=param,

@@ -1629,3 +1461,43 @@ class LightRAG:
        result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def verify_storage_implementation(
        self, storage_type: str, storage_name: str
    ) -> None:
        """Verify if storage implementation is compatible with specified storage type

        Args:
            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
            storage_name: Storage implementation name

        Raises:
            ValueError: If storage implementation is incompatible or missing required methods
        """
        if storage_type not in STORAGE_IMPLEMENTATIONS:
            raise ValueError(f"Unknown storage type: {storage_type}")

        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
        if storage_name not in storage_info["implementations"]:
            raise ValueError(
                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
            )

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check if all required environment variables for storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )

@@ -713,3 +713,47 @@ def get_conversation_turns(
    )

    return "\n".join(formatted_turns)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class

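With `always_get_an_event_loop` and `lazy_external_import` now living in `lightrag.utils`, synchronous callers (such as the query scripts below, which previously redefined the loop helper) can drive async queries along these lines; the `query_sync` wrapper is illustrative only:

from lightrag import QueryParam
from lightrag.utils import always_get_an_event_loop


def query_sync(rag, question: str) -> str:
    # Reuse the current event loop if one exists; otherwise create and register a new one.
    loop = always_get_an_event_loop()
    return loop.run_until_complete(rag.aquery(question, param=QueryParam(mode="hybrid")))
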
@@ -1,7 +1,7 @@
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import always_get_an_event_loop


def extract_queries(file_path):

@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
        return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def run_queries_and_save_to_json(
    queries, rag_instance, query_param, output_file, error_file
):

@@ -1,10 +1,9 @@
import os
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
import numpy as np


@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
        return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def run_queries_and_save_to_json(
    queries, rag_instance, query_param, output_file, error_file
):