@@ -1,5 +1,3 @@
-version: '3.8'
-
 services:
   lightrag:
     build: .
@@ -98,7 +98,6 @@ async def init():
 
     # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
@@ -1,9 +1,8 @@
-import os
 import inspect
+import os
 from lightrag import LightRAG
 from lightrag.llm import openai_complete, openai_embed
-from lightrag.utils import EmbeddingFunc
-from lightrag.lightrag import always_get_an_event_loop
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 from lightrag import QueryParam
 
 # WorkingDir
@@ -63,7 +63,6 @@ async def main():
 
    # Initialize LightRAG
    # We use TiDB DB as the KV/vector
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
@@ -1 +1,136 @@
-# print ("init package vars here. ......")
+STORAGE_IMPLEMENTATIONS = {
+    "KV_STORAGE": {
+        "implementations": [
+            "JsonKVStorage",
+            "MongoKVStorage",
+            "RedisKVStorage",
+            "TiDBKVStorage",
+            "PGKVStorage",
+            "OracleKVStorage",
+        ],
+        "required_methods": ["get_by_id", "upsert"],
+    },
+    "GRAPH_STORAGE": {
+        "implementations": [
+            "NetworkXStorage",
+            "Neo4JStorage",
+            "MongoGraphStorage",
+            "TiDBGraphStorage",
+            "AGEStorage",
+            "GremlinStorage",
+            "PGGraphStorage",
+            "OracleGraphStorage",
+        ],
+        "required_methods": ["upsert_node", "upsert_edge"],
+    },
+    "VECTOR_STORAGE": {
+        "implementations": [
+            "NanoVectorDBStorage",
+            "MilvusVectorDBStorage",
+            "ChromaVectorDBStorage",
+            "TiDBVectorDBStorage",
+            "PGVectorStorage",
+            "FaissVectorDBStorage",
+            "QdrantVectorDBStorage",
+            "OracleVectorDBStorage",
+            "MongoVectorDBStorage",
+        ],
+        "required_methods": ["query", "upsert"],
+    },
+    "DOC_STATUS_STORAGE": {
+        "implementations": [
+            "JsonDocStatusStorage",
+            "PGDocStatusStorage",
+            "PGDocStatusStorage",
+            "MongoDocStatusStorage",
+        ],
+        "required_methods": ["get_docs_by_status"],
+    },
+}
+
+# Storage implementation environment variable without default value
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
+    # KV Storage Implementations
+    "JsonKVStorage": [],
+    "MongoKVStorage": [],
+    "RedisKVStorage": ["REDIS_URI"],
+    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "OracleKVStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Graph Storage Implementations
+    "NetworkXStorage": [],
+    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
+    "MongoGraphStorage": [],
+    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "AGEStorage": [
+        "AGE_POSTGRES_DB",
+        "AGE_POSTGRES_USER",
+        "AGE_POSTGRES_PASSWORD",
+    ],
+    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
+    "PGGraphStorage": [
+        "POSTGRES_USER",
+        "POSTGRES_PASSWORD",
+        "POSTGRES_DATABASE",
+    ],
+    "OracleGraphStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Vector Storage Implementations
+    "NanoVectorDBStorage": [],
+    "MilvusVectorDBStorage": [],
+    "ChromaVectorDBStorage": [],
+    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "FaissVectorDBStorage": [],
+    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
+    "OracleVectorDBStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    "MongoVectorDBStorage": [],
+    # Document Status Storage Implementations
+    "JsonDocStatusStorage": [],
+    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "MongoDocStatusStorage": [],
+}
+
+# Storage implementation module mapping
+STORAGES = {
+    "NetworkXStorage": ".kg.networkx_impl",
+    "JsonKVStorage": ".kg.json_kv_impl",
+    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
+    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
+    "Neo4JStorage": ".kg.neo4j_impl",
+    "OracleKVStorage": ".kg.oracle_impl",
+    "OracleGraphStorage": ".kg.oracle_impl",
+    "OracleVectorDBStorage": ".kg.oracle_impl",
+    "MilvusVectorDBStorage": ".kg.milvus_impl",
+    "MongoKVStorage": ".kg.mongo_impl",
+    "MongoDocStatusStorage": ".kg.mongo_impl",
+    "MongoGraphStorage": ".kg.mongo_impl",
+    "MongoVectorDBStorage": ".kg.mongo_impl",
+    "RedisKVStorage": ".kg.redis_impl",
+    "ChromaVectorDBStorage": ".kg.chroma_impl",
+    "TiDBKVStorage": ".kg.tidb_impl",
+    "TiDBVectorDBStorage": ".kg.tidb_impl",
+    "TiDBGraphStorage": ".kg.tidb_impl",
+    "PGKVStorage": ".kg.postgres_impl",
+    "PGVectorStorage": ".kg.postgres_impl",
+    "AGEStorage": ".kg.age_impl",
+    "PGGraphStorage": ".kg.postgres_impl",
+    "GremlinStorage": ".kg.gremlin_impl",
+    "PGDocStatusStorage": ".kg.postgres_impl",
+    "FaissVectorDBStorage": ".kg.faiss_impl",
+    "QdrantVectorDBStorage": ".kg.qdrant_impl",
+}
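Together these three tables form the storage registry: STORAGE_IMPLEMENTATIONS lists which class names are valid for each storage role, STORAGE_ENV_REQUIREMENTS lists the environment variables each backend cannot run without, and STORAGES maps each class name to the module it is lazily imported from. A minimal sketch of how a consumer might combine them (the resolve_storage helper is hypothetical, not part of this commit):

import os

from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES


def resolve_storage(storage_type: str, storage_name: str) -> str:
    """Validate a backend choice and return the module path it lives in."""
    valid = STORAGE_IMPLEMENTATIONS[storage_type]["implementations"]
    if storage_name not in valid:
        raise ValueError(f"{storage_name} is not a registered {storage_type}")
    missing = [v for v in STORAGE_ENV_REQUIREMENTS[storage_name] if v not in os.environ]
    if missing:
        raise ValueError(f"{storage_name} requires env vars: {', '.join(missing)}")
    return STORAGES[storage_name]  # e.g. ".kg.redis_impl" for RedisKVStorage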
@@ -44,7 +44,7 @@ class OracleDB:
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
         if self.user is None or self.password is None:
-            raise ValueError("Missing database user or password in addon_params")
+            raise ValueError("Missing database user or password")
 
         try:
             oracledb.defaults.fetch_lobs = False
@@ -54,9 +54,7 @@ class PostgreSQLDB:
         self.pool: Pool | None = None
 
         if self.user is None or self.password is None or self.database is None:
-            raise ValueError(
-                "Missing database user, password, or database in addon_params"
-            )
+            raise ValueError("Missing database user, password, or database")
 
     async def initdb(self):
         try:
@@ -6,8 +6,10 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast
-from asyncio import Lock
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final
+
+from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES
 
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -32,8 +34,10 @@ from .operate import (
 from .prompt import GRAPH_FIELD_SEP
 from .utils import (
     EmbeddingFunc,
+    always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
@@ -43,210 +47,22 @@ from .utils import (
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
-# Storage type and implementation compatibility validation table
-STORAGE_IMPLEMENTATIONS = {
-    "KV_STORAGE": {
-        "implementations": [
-            "JsonKVStorage",
-            "MongoKVStorage",
-            "RedisKVStorage",
-            "TiDBKVStorage",
-            "PGKVStorage",
-            "OracleKVStorage",
-        ],
-        "required_methods": ["get_by_id", "upsert"],
-    },
-    "GRAPH_STORAGE": {
-        "implementations": [
-            "NetworkXStorage",
-            "Neo4JStorage",
-            "MongoGraphStorage",
-            "TiDBGraphStorage",
-            "AGEStorage",
-            "GremlinStorage",
-            "PGGraphStorage",
-            "OracleGraphStorage",
-        ],
-        "required_methods": ["upsert_node", "upsert_edge"],
-    },
-    "VECTOR_STORAGE": {
-        "implementations": [
-            "NanoVectorDBStorage",
-            "MilvusVectorDBStorage",
-            "ChromaVectorDBStorage",
-            "TiDBVectorDBStorage",
-            "PGVectorStorage",
-            "FaissVectorDBStorage",
-            "QdrantVectorDBStorage",
-            "OracleVectorDBStorage",
-            "MongoVectorDBStorage",
-        ],
-        "required_methods": ["query", "upsert"],
-    },
-    "DOC_STATUS_STORAGE": {
-        "implementations": [
-            "JsonDocStatusStorage",
-            "PGDocStatusStorage",
-            "PGDocStatusStorage",
-            "MongoDocStatusStorage",
-        ],
-        "required_methods": ["get_docs_by_status"],
-    },
-}
-
-# Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
-    # KV Storage Implementations
-    "JsonKVStorage": [],
-    "MongoKVStorage": [],
-    "RedisKVStorage": ["REDIS_URI"],
-    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "OracleKVStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Graph Storage Implementations
-    "NetworkXStorage": [],
-    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
-    "MongoGraphStorage": [],
-    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "AGEStorage": [
-        "AGE_POSTGRES_DB",
-        "AGE_POSTGRES_USER",
-        "AGE_POSTGRES_PASSWORD",
-    ],
-    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
-    "PGGraphStorage": [
-        "POSTGRES_USER",
-        "POSTGRES_PASSWORD",
-        "POSTGRES_DATABASE",
-    ],
-    "OracleGraphStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Vector Storage Implementations
-    "NanoVectorDBStorage": [],
-    "MilvusVectorDBStorage": [],
-    "ChromaVectorDBStorage": [],
-    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "FaissVectorDBStorage": [],
-    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
-    "OracleVectorDBStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    "MongoVectorDBStorage": [],
-    # Document Status Storage Implementations
-    "JsonDocStatusStorage": [],
-    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "MongoDocStatusStorage": [],
-}
-
-# Storage implementation module mapping
-STORAGES = {
-    "NetworkXStorage": ".kg.networkx_impl",
-    "JsonKVStorage": ".kg.json_kv_impl",
-    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
-    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
-    "Neo4JStorage": ".kg.neo4j_impl",
-    "OracleKVStorage": ".kg.oracle_impl",
-    "OracleGraphStorage": ".kg.oracle_impl",
-    "OracleVectorDBStorage": ".kg.oracle_impl",
-    "MilvusVectorDBStorage": ".kg.milvus_impl",
-    "MongoKVStorage": ".kg.mongo_impl",
-    "MongoDocStatusStorage": ".kg.mongo_impl",
-    "MongoGraphStorage": ".kg.mongo_impl",
-    "MongoVectorDBStorage": ".kg.mongo_impl",
-    "RedisKVStorage": ".kg.redis_impl",
-    "ChromaVectorDBStorage": ".kg.chroma_impl",
-    "TiDBKVStorage": ".kg.tidb_impl",
-    "TiDBVectorDBStorage": ".kg.tidb_impl",
-    "TiDBGraphStorage": ".kg.tidb_impl",
-    "PGKVStorage": ".kg.postgres_impl",
-    "PGVectorStorage": ".kg.postgres_impl",
-    "AGEStorage": ".kg.age_impl",
-    "PGGraphStorage": ".kg.postgres_impl",
-    "GremlinStorage": ".kg.gremlin_impl",
-    "PGDocStatusStorage": ".kg.postgres_impl",
-    "FaissVectorDBStorage": ".kg.faiss_impl",
-    "QdrantVectorDBStorage": ".kg.qdrant_impl",
-}
-
-
-def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
-    """Lazily import a class from an external module based on the package of the caller."""
-    # Get the caller's module and package
-    import inspect
-
-    caller_frame = inspect.currentframe().f_back
-    module = inspect.getmodule(caller_frame)
-    package = module.__package__ if module else None
-
-    def import_class(*args: Any, **kwargs: Any):
-        import importlib
-
-        module = importlib.import_module(module_name, package=package)
-        cls = getattr(module, class_name)
-        return cls(*args, **kwargs)
-
-    return import_class
-
-
-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    """
-    Ensure that there is always an event loop available.
-
-    This function tries to get the current event loop. If the current event loop is closed or does not exist,
-    it creates a new event loop and sets it as the current event loop.
-
-    Returns:
-        asyncio.AbstractEventLoop: The current or newly created event loop.
-    """
-    try:
-        # Try to get the current event loop
-        current_loop = asyncio.get_event_loop()
-        if current_loop.is_closed():
-            raise RuntimeError("Event loop is closed.")
-        return current_loop
-
-    except RuntimeError:
-        # If no event loop exists or it is closed, create a new one
-        logger.info("Creating a new event loop in main thread.")
-        new_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(new_loop)
-        return new_loop
-
-
+@final
 @dataclass
 class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
 
+    # Directory
+    # ---
+
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""
 
-    embedding_cache_config: dict[str, Any] = field(
-        default_factory=lambda: {
-            "enabled": False,
-            "similarity_threshold": 0.95,
-            "use_llm_check": False,
-        }
-    )
-    """Configuration for embedding cache.
-    - enabled: If True, enables caching to avoid redundant computations.
-    - similarity_threshold: Minimum similarity score to use cached embeddings.
-    - use_llm_check: If True, validates cached embeddings using an LLM.
-    """
+    # Storage
+    # ---
 
     kv_storage: str = field(default="JsonKVStorage")
     """Storage backend for key-value data."""
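One side effect of this hunk worth noting: replacing `default_factory` with `default` for `working_dir` means the f-string, including its `datetime.now()` call, is evaluated once when the class body executes rather than on each instantiation, so every `LightRAG()` created in the same process shares one timestamped default directory name. A minimal sketch of the difference:

from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class EagerDefault:
    # Evaluated once, at class-definition time; shared by every instance
    d: str = f"./cache_{datetime.now().strftime('%H-%M-%S')}"


@dataclass
class LazyDefault:
    # Evaluated anew on every instantiation
    d: str = field(default_factory=lambda: f"./cache_{datetime.now().strftime('%H-%M-%S')}")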
@@ -261,32 +77,74 @@ class LightRAG:
     """Storage type for tracking document processing statuses."""
 
     # Logging
-    current_log_level = logger.level
-    log_level: int = field(default=current_log_level)
+    # ---
+
+    log_level: int = field(default=logger.level)
     """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
 
-    log_dir: str = field(default=os.getcwd())
-    """Directory where logs are stored. Defaults to the current working directory."""
-
-    # Text chunking
-    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
-    """Maximum number of tokens per text chunk when splitting documents."""
-
-    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
-    """Number of overlapping tokens between consecutive text chunks to preserve context."""
-
-    tiktoken_model_name: str = "gpt-4o-mini"
-    """Model name used for tokenization when chunking text."""
+    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
+    """Log file path."""
 
     # Entity extraction
-    entity_extract_max_gleaning: int = 1
+    # ---
+
+    entity_extract_max_gleaning: int = field(default=1)
     """Maximum number of entity extraction attempts for ambiguous content."""
 
-    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
+    entity_summary_to_max_tokens: int = field(
+        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    )
+
+    # Text chunking
+    # ---
+
+    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
+    """Maximum number of tokens per text chunk when splitting documents."""
+
+    chunk_overlap_token_size: int = field(
+        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
+    )
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
+    tiktoken_model_name: str = field(default="gpt-4o-mini")
+    """Model name used for tokenization when chunking text."""
     """Maximum number of tokens used for summarizing extracted entities."""
+
+    chunking_func: Callable[
+        [
+            str,
+            str | None,
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = field(default_factory=lambda: chunking_by_token_size)
+    """
+    Custom chunking function for splitting text into chunks before processing.
+
+    The function should take the following parameters:
+
+    - `content`: The text to be split into chunks.
+    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
+    - `split_by_character_only`: If True, the text is split only on the specified character.
+    - `chunk_token_size`: The maximum number of tokens per chunk.
+    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
+    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
+
+    The function should return a list of dictionaries, where each dictionary contains the following keys:
+    - `tokens`: The number of tokens in the chunk.
+    - `content`: The text content of the chunk.
+
+    Defaults to `chunking_by_token_size` if not specified.
+    """
 
     # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+
+    node_embedding_algorithm: str = field(default="node2vec")
     """Algorithm used for node embedding in knowledge graphs."""
 
     node2vec_params: dict[str, int] = field(
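The docstring above fully specifies the chunking contract, so a drop-in replacement only has to honor the positional signature and return `tokens`/`content` dicts. An illustrative sketch (it approximates token counts by whitespace splitting and ignores the overlap parameter, so it is not a production chunker):

def paragraph_chunker(
    content: str,
    split_by_character: str | None,
    split_by_character_only: bool,
    chunk_token_size: int,
    chunk_overlap_token_size: int,
    tiktoken_model_name: str,
) -> list[dict]:
    # Split on the requested character, falling back to blank lines
    separator = split_by_character or "\n\n"
    return [
        {"tokens": len(part.split()), "content": part.strip()}
        for part in content.split(separator)
        if part.strip()
    ]

# usage: LightRAG(chunking_func=paragraph_chunker, ...)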
@@ -308,119 +166,98 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """
 
-    embedding_func: EmbeddingFunc | None = None
+    # Embedding
+    # ---
+
+    embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""
 
-    embedding_batch_num: int = 32
+    embedding_batch_num: int = field(default=32)
     """Batch size for embedding computations."""
 
-    embedding_func_max_async: int = 16
+    embedding_func_max_async: int = field(default=16)
     """Maximum number of concurrent embedding function calls."""
 
+    embedding_cache_config: dict[str, Any] = field(
+        default={
+            "enabled": False,
+            "similarity_threshold": 0.95,
+            "use_llm_check": False,
+        }
+    )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """
+
     # LLM Configuration
-    llm_model_func: Callable[..., object] | None = None
+    # ---
+
+    llm_model_func: Callable[..., object] | None = field(default=None)
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    llm_model_name: str = field(default="gpt-4o-mini")
     """Name of the LLM model used for generating responses."""
 
-    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
+    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
     """Maximum number of tokens allowed per LLM response."""
 
-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
     """Maximum number of concurrent LLM calls."""
 
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
 
     # Storage
+    # ---
+
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""
 
     namespace_prefix: str = field(default="")
     """Prefix for namespacing stored data across different environments."""
 
-    enable_llm_cache: bool = True
+    enable_llm_cache: bool = field(default=True)
     """Enables caching for LLM responses to avoid redundant computations."""
 
-    enable_llm_cache_for_entity_extract: bool = True
+    enable_llm_cache_for_entity_extract: bool = field(default=True)
     """If True, enables caching for entity extraction steps to reduce LLM costs."""
 
     # Extensions
+    # ---
+
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
+    """Maximum number of parallel insert operations."""
+
     addon_params: dict[str, Any] = field(default_factory=dict)
 
     # Storages Management
-    auto_manage_storages_states: bool = True
+    # ---
+
+    auto_manage_storages_states: bool = field(default=True)
     """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
 
-    """Dictionary for additional parameters and extensions."""
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
-        convert_response_to_json
+    # Storages Management
+    # ---
+
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
+        default_factory=lambda: convert_response_to_json
     )
+    """
+    Custom function for converting LLM responses to JSON format.
+
+    The default function is :func:`.utils.convert_response_to_json`.
+    """
 
-    # Lock for entity extraction
-    _entity_lock = Lock()
+    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
 
-    # Custom Chunking Function
-    chunking_func: Callable[
-        [
-            str,
-            str | None,
-            bool,
-            int,
-            int,
-            str,
-        ],
-        list[dict[str, Any]],
-    ] = chunking_by_token_size
-
-    def verify_storage_implementation(
-        self, storage_type: str, storage_name: str
-    ) -> None:
-        """Verify if storage implementation is compatible with specified storage type
-
-        Args:
-            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If storage implementation is incompatible or missing required methods
-        """
-        if storage_type not in STORAGE_IMPLEMENTATIONS:
-            raise ValueError(f"Unknown storage type: {storage_type}")
-
-        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
-        if storage_name not in storage_info["implementations"]:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
-                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
-            )
-
-    def check_storage_env_vars(self, storage_name: str) -> None:
-        """Check if all required environment variables for storage implementation exist
-
-        Args:
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If required environment variables are missing
-        """
-        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
-        missing_vars = [var for var in required_vars if var not in os.environ]
-
-        if missing_vars:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' requires the following "
-                f"environment variables: {', '.join(missing_vars)}"
-            )
-
     def __post_init__(self):
-        os.makedirs(self.log_dir, exist_ok=True)
-        log_file = os.path.join(self.log_dir, "lightrag.log")
-        set_logger(log_file)
-
         logger.setLevel(self.log_level)
+        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
+        set_logger(self.log_file_path)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
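A caution on the `embedding_cache_config` hunk: `dataclasses` rejects mutable defaults passed as `field(default=...)`, raising `ValueError` at class-definition time, which is why dict-valued fields normally use `default_factory` (as `llm_model_kwargs` above does). A runnable sketch of the behavior:

from dataclasses import dataclass, field


@dataclass
class Ok:
    cfg: dict = field(default_factory=lambda: {"enabled": False})


try:
    @dataclass
    class Broken:
        cfg: dict = field(default={"enabled": False})
except ValueError as exc:
    # mutable default <class 'dict'> for field cfg is not allowed: use default_factory
    print(exc)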
@@ -448,9 +285,6 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }
 
-        # Life cycle
-        self.storages_status = StoragesStatus.NOT_CREATED
-
         # Show config
         global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -558,7 +392,7 @@ class LightRAG:
             )
         )
 
-        self.storages_status = StoragesStatus.CREATED
+        self._storages_status = StoragesStatus.CREATED
 
         # Initialize storages
         if self.auto_manage_storages_states:
@@ -573,7 +407,7 @@ class LightRAG:
 
     async def initialize_storages(self):
         """Asynchronously initialize the storages"""
-        if self.storages_status == StoragesStatus.CREATED:
+        if self._storages_status == StoragesStatus.CREATED:
             tasks = []
 
             for storage in (
@@ -591,12 +425,12 @@ class LightRAG:
 
             await asyncio.gather(*tasks)
 
-            self.storages_status = StoragesStatus.INITIALIZED
+            self._storages_status = StoragesStatus.INITIALIZED
             logger.debug("Initialized Storages")
 
     async def finalize_storages(self):
         """Asynchronously finalize the storages"""
-        if self.storages_status == StoragesStatus.INITIALIZED:
+        if self._storages_status == StoragesStatus.INITIALIZED:
             tasks = []
 
             for storage in (
@@ -614,7 +448,7 @@ class LightRAG:
 
             await asyncio.gather(*tasks)
 
-            self.storages_status = StoragesStatus.FINALIZED
+            self._storages_status = StoragesStatus.FINALIZED
             logger.debug("Finalized Storages")
 
     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
@@ -789,10 +623,9 @@ class LightRAG:
             return
 
         # 2. split docs into chunks, insert chunks, update doc status
-        batch_size = self.addon_params.get("insert_batch_size", 10)
         docs_batches = [
-            list(to_process_docs.items())[i : i + batch_size]
-            for i in range(0, len(to_process_docs), batch_size)
+            list(to_process_docs.items())[i : i + self.max_parallel_insert]
+            for i in range(0, len(to_process_docs), self.max_parallel_insert)
         ]
 
         logger.info(f"Number of batches to process: {len(docs_batches)}.")
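The rewritten comprehension keeps the same slicing pattern but sources the batch size from `max_parallel_insert` (env `MAX_PARALLEL_INSERT`, default 20) instead of `addon_params["insert_batch_size"]` (old default 10). The slicing itself yields ceil(N / k) batches for N pending documents; a standalone sketch:

def batched(items: list, k: int) -> list[list]:
    # Same i : i + k stride used for docs_batches above
    return [items[i : i + k] for i in range(0, len(items), k)]


assert batched(list(range(7)), 3) == [[0, 1, 2], [3, 4, 5], [6]]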
@@ -1203,7 +1036,6 @@ class LightRAG:
         # ---------------------
         # STEP 1: Keyword Extraction
         # ---------------------
-        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
         hl_keywords, ll_keywords = await extract_keywords_only(
             text=query,
             param=param,
@@ -1629,3 +1461,43 @@ class LightRAG:
         result["vector_data"] = vector_data[0] if vector_data else None
 
         return result
+
+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify if storage implementation is compatible with specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check if all required environment variables for storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )
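Typical use of the two new guards on a constructed instance (values here are illustrative, not from the commit):

rag = LightRAG(working_dir="./rag_cache", kv_storage="RedisKVStorage")
rag.verify_storage_implementation("KV_STORAGE", "RedisKVStorage")  # passes: registered for KV_STORAGE
rag.check_storage_env_vars("RedisKVStorage")  # raises ValueError unless REDIS_URI is set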
@@ -713,3 +713,47 @@ def get_conversation_turns(
     )
 
     return "\n".join(formatted_turns)
+
+
+def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
+    """
+    Ensure that there is always an event loop available.
+
+    This function tries to get the current event loop. If the current event loop is closed or does not exist,
+    it creates a new event loop and sets it as the current event loop.
+
+    Returns:
+        asyncio.AbstractEventLoop: The current or newly created event loop.
+    """
+    try:
+        # Try to get the current event loop
+        current_loop = asyncio.get_event_loop()
+        if current_loop.is_closed():
+            raise RuntimeError("Event loop is closed.")
+        return current_loop
+
+    except RuntimeError:
+        # If no event loop exists or it is closed, create a new one
+        logger.info("Creating a new event loop in main thread.")
+        new_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(new_loop)
+        return new_loop
+
+
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
+    """Lazily import a class from an external module based on the package of the caller."""
+    # Get the caller's module and package
+    import inspect
+
+    caller_frame = inspect.currentframe().f_back
+    module = inspect.getmodule(caller_frame)
+    package = module.__package__ if module else None
+
+    def import_class(*args: Any, **kwargs: Any):
+        import importlib
+
+        module = importlib.import_module(module_name, package=package)
+        cls = getattr(module, class_name)
+        return cls(*args, **kwargs)
+
+    return import_class
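With both helpers now exported from `lightrag.utils`, synchronous callers can drive the async API in one place; a sketch of the pattern the example scripts below switch to (the query_sync wrapper is hypothetical):

from lightrag import LightRAG, QueryParam
from lightrag.utils import always_get_an_event_loop


def query_sync(rag: LightRAG, question: str) -> str:
    # Reuse (or create) an event loop instead of redefining the helper per script
    loop = always_get_an_event_loop()
    return loop.run_until_complete(rag.aquery(question, param=QueryParam(mode="hybrid")))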
@@ -1,7 +1,7 @@
 import re
 import json
-import asyncio
 from lightrag import LightRAG, QueryParam
+from lightrag.utils import always_get_an_event_loop
 
 
 def extract_queries(file_path):
@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
         return None, {"query": query_text, "error": str(e)}
 
 
-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop
-
-
 def run_queries_and_save_to_json(
     queries, rag_instance, query_param, output_file, error_file
 ):
@@ -1,10 +1,9 @@
 import os
 import re
 import json
-import asyncio
 from lightrag import LightRAG, QueryParam
 from lightrag.llm.openai import openai_complete_if_cache, openai_embed
-from lightrag.utils import EmbeddingFunc
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 import numpy as np
 
 
@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
         return None, {"query": query_text, "error": str(e)}
 
 
-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop
-
-
 def run_queries_and_save_to_json(
     queries, rag_instance, query_param, output_file, error_file
 ):