Fix linting

yangdx
2025-02-26 18:11:16 +08:00
parent 7d12715f09
commit 7436c06f6c
11 changed files with 205 additions and 144 deletions

View File

@@ -141,4 +141,4 @@ QDRANT_URL=http://localhost:16333
# QDRANT_API_KEY=your-api-key

### Redis
REDIS_URI=redis://localhost:6379

View File

@@ -54,11 +54,12 @@ config.read("config.ini")
class LightragPathFilter(logging.Filter):
    """Filter for lightrag logger to filter out frequent path access logs"""

    def __init__(self):
        super().__init__()
        # Define paths to be filtered
        self.filtered_paths = ["/documents", "/health", "/webui/"]

    def filter(self, record):
        try:
            # Check if record has the required attributes for an access log
@@ -90,11 +91,13 @@ def create_app(args):
    # Initialize verbose debug setting
    # Can not use the logger at the top of this module when workers > 1
    from lightrag.utils import set_verbose_debug, logger

    # Setup logging
    logger.setLevel(getattr(logging, args.log_level))
    set_verbose_debug(args.verbose)

    from lightrag.kg.shared_storage import is_multiprocess

    logger.info(f"==== Multi-processor mode: {is_multiprocess} ====")

    # Verify that bindings are correctly setup
@@ -147,9 +150,7 @@ def create_app(args):
        # Auto scan documents if enabled
        if args.auto_scan_at_startup:
            # Create background task
            task = asyncio.create_task(run_scanning_process(rag, doc_manager))
            app.state.background_tasks.add(task)
            task.add_done_callback(app.state.background_tasks.discard)
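
The task bookkeeping in this hunk follows the standard asyncio idiom of keeping a strong reference to a fire-and-forget task so it cannot be garbage-collected mid-flight, then dropping the reference once the task finishes. A minimal, self-contained sketch of that idiom (names are illustrative, not from the diff):

import asyncio

background_tasks: set[asyncio.Task] = set()

async def job() -> None:
    await asyncio.sleep(0.1)

async def main() -> None:
    task = asyncio.create_task(job())
    background_tasks.add(task)                        # keep a strong reference
    task.add_done_callback(background_tasks.discard)  # drop it once finished
    await asyncio.sleep(0.2)                          # give the task time to run

asyncio.run(main())
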
@@ -411,17 +412,19 @@ def get_application():
"""Factory function for creating the FastAPI application""" """Factory function for creating the FastAPI application"""
# Configure logging for this worker process # Configure logging for this worker process
configure_logging() configure_logging()
# Get args from environment variable # Get args from environment variable
args_json = os.environ.get('LIGHTRAG_ARGS') args_json = os.environ.get("LIGHTRAG_ARGS")
if not args_json: if not args_json:
args = parse_args() # Fallback to parsing args if env var not set args = parse_args() # Fallback to parsing args if env var not set
else: else:
import types import types
args = types.SimpleNamespace(**json.loads(args_json)) args = types.SimpleNamespace(**json.loads(args_json))
if args.workers > 1: if args.workers > 1:
from lightrag.kg.shared_storage import initialize_share_data from lightrag.kg.shared_storage import initialize_share_data
initialize_share_data() initialize_share_data()
return create_app(args) return create_app(args)
@@ -434,58 +437,61 @@ def configure_logging():
        logger = logging.getLogger(logger_name)
        logger.handlers = []
        logger.filters = []

    # Configure basic logging
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
            },
            "handlers": {
                "default": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
            },
            "loggers": {
                "uvicorn.access": {
                    "handlers": ["default"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
                "lightrag": {
                    "handlers": ["default"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
            },
            "filters": {
                "path_filter": {
                    "()": "lightrag.api.lightrag_server.LightragPathFilter",
                },
            },
        }
    )


def main():
    from multiprocessing import freeze_support

    freeze_support()
    args = parse_args()

    # Save args to environment variable for child processes
    os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args))

    # Configure logging before starting uvicorn
    configure_logging()
    display_splash_screen(args)

    uvicorn_config = {
        "app": "lightrag.api.lightrag_server:get_application",
        "factory": True,

View File

@@ -375,62 +375,70 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    """Background task to scan and index documents"""
    scan_progress = get_scan_progress()
    scan_lock = get_scan_lock()

    # Initialize scan_progress if not already initialized
    if not scan_progress:
        scan_progress.update(
            {
                "is_scanning": False,
                "current_file": "",
                "indexed_count": 0,
                "total_files": 0,
                "progress": 0,
            }
        )

    with scan_lock:
        if scan_progress.get("is_scanning", False):
            ASCIIColors.info("Skip document scanning(another scanning is active)")
            return
        scan_progress.update(
            {
                "is_scanning": True,
                "current_file": "",
                "indexed_count": 0,
                "total_files": 0,
                "progress": 0,
            }
        )

    try:
        new_files = doc_manager.scan_directory_for_new_files()
        total_files = len(new_files)
        scan_progress.update(
            {
                "current_file": "",
                "total_files": total_files,
                "indexed_count": 0,
                "progress": 0,
            }
        )

        logging.info(f"Found {total_files} new files to index.")
        for idx, file_path in enumerate(new_files):
            try:
                progress = (idx / total_files * 100) if total_files > 0 else 0
                scan_progress.update(
                    {
                        "current_file": os.path.basename(file_path),
                        "indexed_count": idx,
                        "progress": progress,
                    }
                )

                await pipeline_index_file(rag, file_path)

                progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
                scan_progress.update(
                    {
                        "current_file": os.path.basename(file_path),
                        "indexed_count": idx + 1,
                        "progress": progress,
                    }
                )

            except Exception as e:
                logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -438,13 +446,15 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    except Exception as e:
        logging.error(f"Error during scanning process: {str(e)}")
    finally:
        scan_progress.update(
            {
                "is_scanning": False,
                "current_file": "",
                "indexed_count": 0,
                "total_files": 0,
                "progress": 0,
            }
        )


def create_document_routes(

View File

@@ -433,7 +433,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.white(" └─ Document Status Storage: ", end="")
    ASCIIColors.yellow(f"{args.doc_status_storage}")

    # Server Status
    ASCIIColors.green("\n✨ Server starting up...\n")

View File

@@ -8,14 +8,19 @@ import numpy as np
from dataclasses import dataclass
import pipmaster as pm
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
from .shared_storage import (
    get_namespace_data,
    get_storage_lock,
    get_namespace_object,
    is_multiprocess,
)

if not pm.is_installed("faiss"):
    pm.install("faiss")

import faiss  # type: ignore


@final
@@ -46,10 +51,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
        # Embedding dimension (e.g. 768) must match your embedding function
        self._dim = self.embedding_func.embedding_dim
        self._storage_lock = get_storage_lock()
        self._index = get_namespace_object("faiss_indices")
        self._id_to_meta = get_namespace_data("faiss_meta")

        with self._storage_lock:
            if is_multiprocess:
                if self._index.value is None:
@@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                    self._id_to_meta.update({})
                    self._load_faiss_index()

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """
        Insert or update vectors in the Faiss index.
@@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
        # Perform the similarity search
        with self._storage_lock:
            distances, indices = (
                self._index.value if is_multiprocess else self._index
            ).search(embedding, top_k)

        distances = distances[0]
        indices = indices[0]
@@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
        with self._storage_lock:
            relations = []
            for fid, meta in self._id_to_meta.items():
                if (
                    meta.get("src_id") == entity_name
                    or meta.get("tgt_id") == entity_name
                ):
                    relations.append(fid)

            logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
        Save the current Faiss index + metadata to disk so it can persist across runs.
        """
        with self._storage_lock:
            faiss.write_index(
                self._index.value if is_multiprocess else self._index,
                self._faiss_index_file,
            )

            # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
            # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
@@ -320,7 +332,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                self._index.value = loaded_index
            else:
                self._index = loaded_index

            # Load metadata
            with open(self._meta_file, "r", encoding="utf-8") as f:
                stored_dict = json.load(f)
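
The expression `self._index.value if is_multiprocess else self._index`, repeated throughout this file, is the multiprocess accessor pattern: when workers > 1 the index sits behind a `Manager().Value` proxy and the real object must be unwrapped through `.value`. A minimal standalone sketch of that unwrap (helper names are hypothetical, not part of the diff):

from multiprocessing import Manager

def make_holder(multiprocess: bool):
    if multiprocess:
        return Manager().Value("O", None)  # shared proxy, same call the diff uses
    return None  # plain in-process reference

def get_index(holder, multiprocess: bool):
    # Mirrors: self._index.value if is_multiprocess else self._index
    return holder.value if multiprocess else holder

if __name__ == "__main__":
    holder = make_holder(multiprocess=False)
    print(get_index(holder, multiprocess=False))  # -> None
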

View File

@@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage):
        self._data: dict[str, Any] = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def index_done_callback(self) -> None:
        # File writes must be locked to stop concurrent writes from multiple processes corrupting the file
        with self._storage_lock:

View File

@@ -25,7 +25,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
    def __post_init__(self):
        # Initialize lock only for file operations
        self._storage_lock = get_storage_lock()

        # Use global config value if specified, otherwise use default
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -39,22 +39,28 @@ class NanoVectorDBStorage(BaseVectorStorage):
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]

        self._client = get_namespace_object(self.namespace)

        with self._storage_lock:
            if is_multiprocess:
                if self._client.value is None:
                    self._client.value = NanoVectorDB(
                        self.embedding_func.embedding_dim,
                        storage_file=self._client_file_name,
                    )
                    logger.info(
                        f"Initialized vector DB client for namespace {self.namespace}"
                    )
            else:
                if self._client is None:
                    self._client = NanoVectorDB(
                        self.embedding_func.embedding_dim,
                        storage_file=self._client_file_name,
                    )
                    logger.info(
                        f"Initialized vector DB client for namespace {self.namespace}"
                    )

    def _get_client(self):
        """Get the appropriate client instance based on multiprocess mode"""
@@ -104,7 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        # Execute embedding outside of lock to avoid long lock times
        embedding = await self.embedding_func([query])
        embedding = embedding[0]

        with self._storage_lock:
            client = self._get_client()
            results = client.query(
@@ -150,7 +156,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )

            with self._storage_lock:
                client = self._get_client()
                # Check if the entity exists
@@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
                    for dp in storage["data"]
                    if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
                ]
                logger.debug(
                    f"Found {len(relations)} relations for entity {entity_name}"
                )
                ids_to_delete = [relation["__id__"] for relation in relations]

                if ids_to_delete:

View File

@@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage):
        with self._storage_lock:
            if is_multiprocess:
                if self._graph.value is None:
                    preloaded_graph = NetworkXStorage.load_nx_graph(
                        self._graphml_xml_file
                    )
                    self._graph.value = preloaded_graph or nx.Graph()
                    if preloaded_graph:
                        logger.info(
                            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                        )
                    else:
                        logger.info("Created new empty graph")
            else:
                if self._graph is None:
                    preloaded_graph = NetworkXStorage.load_nx_graph(
                        self._graphml_xml_file
                    )
                    self._graph = preloaded_graph or nx.Graph()
                    if preloaded_graph:
                        logger.info(
                            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                        )
                    else:
                        logger.info("Created new empty graph")

        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    def _get_graph(self):
        """Get the appropriate graph instance based on multiprocess mode"""
        if is_multiprocess:
@@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage):
        with self._storage_lock:
            graph = self._get_graph()

            # Handle special case for "*" label
            if node_label == "*":
                # For "*", return the entire graph including all nodes and edges
                subgraph = (
                    graph.copy()
                )  # Create a copy to avoid modifying the original graph
            else:
                # Find nodes with matching node id (partial match)
                nodes_to_explore = []
@@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage):
            if len(subgraph.nodes()) > max_graph_nodes:
                origin_nodes = len(subgraph.nodes())
                node_degrees = dict(subgraph.degree())
                top_nodes = sorted(
                    node_degrees.items(), key=lambda x: x[1], reverse=True
                )[:max_graph_nodes]
                top_node_ids = [node[0] for node in top_nodes]
                # Create new subgraph with only top nodes
                subgraph = subgraph.subgraph(top_node_ids)
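
One detail worth noting in the hunk above: in networkx, `graph.subgraph(nodes)` returns a read-only view tied to the original graph, while `.copy()` materializes an independent, editable graph, which is consistent with the comment about avoiding modification of the original. A small standalone illustration (sample data is made up, not from the diff):

import networkx as nx

g = nx.Graph()
g.add_edges_from([("a", "b"), ("b", "c"), ("c", "d")])

# Keep only the two highest-degree nodes, as the hunk above does.
degrees = dict(g.degree())
top_ids = [n for n, _ in sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:2]]

view = g.subgraph(top_ids)  # read-only view onto g
sub = view.copy()           # independent graph that can be modified
sub.add_node("e")
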

View File

@@ -17,106 +17,125 @@ _shared_dicts: Optional[Dict[str, Any]] = {}
_share_objects: Optional[Dict[str, Any]] = {}
_init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized


def initialize_share_data():
    """Initialize shared data, only called if multiple processes where workers > 1"""
    global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess

    is_multiprocess = True
    logger.info(f"Process {os.getpid()} initializing shared storage")

    # Initialize manager
    if _manager is None:
        _manager = Manager()
        logger.info(f"Process {os.getpid()} created manager")

    # Create shared dictionaries with manager
    _shared_dicts = _manager.dict()
    _share_objects = _manager.dict()
    _init_flags = _manager.dict()  # Use a shared dict to store initialization flags
    logger.info(f"Process {os.getpid()} created shared dictionaries")


def try_initialize_namespace(namespace: str) -> bool:
    """
    Try to initialize a namespace. Returns True if the current process obtained the
    right to initialize it. The atomic operations of the shared dict ensure that
    only one process can successfully initialize a given namespace.
    """
    global _init_flags, _manager

    if is_multiprocess:
        if _init_flags is None:
            raise RuntimeError(
                "Shared storage not initialized. Call initialize_share_data() first."
            )
    else:
        if _init_flags is None:
            _init_flags = {}

    logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")

    # Use the global lock to protect access to the shared dict
    with _get_global_lock():
        # Check whether the namespace is already initialized
        if namespace not in _init_flags:
            # Set the initialization flag
            _init_flags[namespace] = True
            logger.info(
                f"Process {os.getpid()} ready to initialize namespace {namespace}"
            )
            return True
        logger.info(
            f"Process {os.getpid()} found namespace {namespace} already initialized"
        )
        return False


def _get_global_lock() -> LockType:
    global _global_lock, is_multiprocess, _manager

    if _global_lock is None:
        if is_multiprocess:
            _global_lock = _manager.Lock()  # Use manager for lock
        else:
            _global_lock = ThreadLock()

    return _global_lock


def get_storage_lock() -> LockType:
    """return storage lock for data consistency"""
    return _get_global_lock()


def get_scan_lock() -> LockType:
    """return scan_progress lock for data consistency"""
    return get_storage_lock()


def get_namespace_object(namespace: str) -> Any:
    """Get an object for specific namespace"""
    global _share_objects, is_multiprocess, _manager

    if is_multiprocess and not _manager:
        raise RuntimeError(
            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
        )

    if namespace not in _share_objects:
        lock = _get_global_lock()
        with lock:
            if namespace not in _share_objects:
                if is_multiprocess:
                    _share_objects[namespace] = _manager.Value("O", None)
                else:
                    _share_objects[namespace] = None

    return _share_objects[namespace]


# Functions that are no longer used have been removed
def get_namespace_data(namespace: str) -> Dict[str, Any]:
    """get storage space for specific storage type(namespace)"""
    global _shared_dicts, is_multiprocess, _manager

    if is_multiprocess and not _manager:
        raise RuntimeError(
            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
        )

    if namespace not in _shared_dicts:
        lock = _get_global_lock()
        with lock:
            if namespace not in _shared_dicts:
                _shared_dicts[namespace] = {}

    return _shared_dicts[namespace]


def get_scan_progress() -> Dict[str, Any]:
    """get storage space for document scanning progress data"""
    return get_namespace_data("scan_progress")
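
Taken together, the functions in this file form a small shared-storage API: the parent process switches on multi-process mode once, and any process can then fetch the same namespaced dict and serialize writes through the global lock. A brief usage sketch under that assumption (the flow is inferred from the code above, not shown in the diff):

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_storage_lock,
    initialize_share_data,
)

if __name__ == "__main__":
    initialize_share_data()  # only needed when running with workers > 1
    progress = get_namespace_data("scan_progress")
    with get_storage_lock():
        progress.update({"is_scanning": False, "progress": 0})
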

View File

@@ -266,7 +266,7 @@ class LightRAG:
    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

    def __post_init__(self):
        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
        set_logger(self.log_file_path, self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

View File

@@ -55,6 +55,7 @@ def set_verbose_debug(enabled: bool):
    global VERBOSE_DEBUG
    VERBOSE_DEBUG = enabled


statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}

# Initialize logger
@@ -100,6 +101,7 @@ class UnlimitedSemaphore:
ENCODER = None


@dataclass
class EmbeddingFunc:
    embedding_dim: int