Fix linting
@@ -54,6 +54,7 @@ config.read("config.ini")
 
 class LightragPathFilter(logging.Filter):
     """Filter for lightrag logger to filter out frequent path access logs"""
 
     def __init__(self):
         super().__init__()
         # Define paths to be filtered
@@ -90,11 +91,13 @@ def create_app(args):
     # Initialize verbose debug setting
     # Can not use the logger at the top of this module when workers > 1
     from lightrag.utils import set_verbose_debug, logger
 
     # Setup logging
     logger.setLevel(getattr(logging, args.log_level))
     set_verbose_debug(args.verbose)
 
     from lightrag.kg.shared_storage import is_multiprocess
 
     logger.info(f"==== Multi-processor mode: {is_multiprocess} ====")
 
     # Verify that bindings are correctly setup
@@ -147,9 +150,7 @@ def create_app(args):
         # Auto scan documents if enabled
         if args.auto_scan_at_startup:
             # Create background task
-            task = asyncio.create_task(
-                run_scanning_process(rag, doc_manager)
-            )
+            task = asyncio.create_task(run_scanning_process(rag, doc_manager))
             app.state.background_tasks.add(task)
             task.add_done_callback(app.state.background_tasks.discard)
 
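
Note: the hunk above only reflows the task creation, but the surrounding pattern — adding the task to `app.state.background_tasks` and discarding it in a done-callback — is the usual way to keep a strong reference to a fire-and-forget asyncio task so it is not garbage-collected before it finishes. A minimal self-contained sketch of the same idea (names here are illustrative, not part of this codebase):

    import asyncio

    background_tasks: set[asyncio.Task] = set()

    async def scan_documents() -> None:
        # Stand-in for run_scanning_process(rag, doc_manager)
        await asyncio.sleep(1)

    def schedule_scan() -> None:
        # Keep a strong reference so the task is not garbage-collected mid-run,
        # then drop the reference automatically once it completes.
        task = asyncio.create_task(scan_documents())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
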
@@ -413,15 +414,17 @@ def get_application():
     configure_logging()
 
     # Get args from environment variable
-    args_json = os.environ.get('LIGHTRAG_ARGS')
+    args_json = os.environ.get("LIGHTRAG_ARGS")
     if not args_json:
         args = parse_args()  # Fallback to parsing args if env var not set
     else:
         import types
 
         args = types.SimpleNamespace(**json.loads(args_json))
 
     if args.workers > 1:
         from lightrag.kg.shared_storage import initialize_share_data
 
         initialize_share_data()
 
     return create_app(args)
@@ -436,7 +439,8 @@ def configure_logging():
     logger.filters = []
 
     # Configure basic logging
-    logging.config.dictConfig({
+    logging.config.dictConfig(
+        {
             "version": 1,
             "disable_existing_loggers": False,
             "formatters": {
@@ -470,22 +474,24 @@ def configure_logging():
                     "()": "lightrag.api.lightrag_server.LightragPathFilter",
                 },
             },
-    })
+        }
+    )
 
 
 def main():
     from multiprocessing import freeze_support
 
     freeze_support()
 
     args = parse_args()
     # Save args to environment variable for child processes
-    os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
+    os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args))
 
     # Configure logging before starting uvicorn
     configure_logging()
 
     display_splash_screen(args)
 
 
     uvicorn_config = {
         "app": "lightrag.api.lightrag_server:get_application",
         "factory": True,
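
Note: `main()` serializes the parsed CLI arguments into the `LIGHTRAG_ARGS` environment variable so that `get_application()`, which uvicorn calls as a factory in each worker process, can rebuild them without re-parsing `sys.argv`. A rough sketch of that hand-off, with a hypothetical variable name and stripped of the LightRAG specifics:

    import argparse
    import json
    import os
    import types

    ENV_KEY = "MYAPP_ARGS"  # hypothetical; the diff uses LIGHTRAG_ARGS

    def save_args(args: argparse.Namespace) -> None:
        # Parent process: stash the parsed args where spawned workers can see them.
        os.environ[ENV_KEY] = json.dumps(vars(args))

    def load_args() -> types.SimpleNamespace:
        # Worker process: rebuild an attribute-style namespace from the JSON payload.
        raw = os.environ.get(ENV_KEY)
        if raw is None:
            raise RuntimeError(f"{ENV_KEY} is not set")
        return types.SimpleNamespace(**json.loads(raw))
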
@@ -381,56 +381,64 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
 
     # Initialize scan_progress if not already initialized
     if not scan_progress:
-        scan_progress.update({
+        scan_progress.update(
+            {
                 "is_scanning": False,
                 "current_file": "",
                 "indexed_count": 0,
                 "total_files": 0,
                 "progress": 0,
-        })
+            }
+        )
 
     with scan_lock:
         if scan_progress.get("is_scanning", False):
-            ASCIIColors.info(
-                "Skip document scanning(another scanning is active)"
-            )
+            ASCIIColors.info("Skip document scanning(another scanning is active)")
             return
-        scan_progress.update({
+        scan_progress.update(
+            {
                 "is_scanning": True,
                 "current_file": "",
                 "indexed_count": 0,
                 "total_files": 0,
                 "progress": 0,
-        })
+            }
+        )
 
     try:
         new_files = doc_manager.scan_directory_for_new_files()
         total_files = len(new_files)
-        scan_progress.update({
+        scan_progress.update(
+            {
                 "current_file": "",
                 "total_files": total_files,
                 "indexed_count": 0,
                 "progress": 0,
-        })
+            }
+        )
 
         logging.info(f"Found {total_files} new files to index.")
         for idx, file_path in enumerate(new_files):
             try:
                 progress = (idx / total_files * 100) if total_files > 0 else 0
-                scan_progress.update({
+                scan_progress.update(
+                    {
                         "current_file": os.path.basename(file_path),
                         "indexed_count": idx,
                         "progress": progress,
-                })
+                    }
+                )
 
                 await pipeline_index_file(rag, file_path)
 
                 progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
-                scan_progress.update({
+                scan_progress.update(
+                    {
                         "current_file": os.path.basename(file_path),
                         "indexed_count": idx + 1,
                         "progress": progress,
-                })
+                    }
+                )
 
             except Exception as e:
                 logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -438,13 +446,15 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     except Exception as e:
         logging.error(f"Error during scanning process: {str(e)}")
     finally:
-        scan_progress.update({
+        scan_progress.update(
+            {
                 "is_scanning": False,
                 "current_file": "",
                 "indexed_count": 0,
                 "total_files": 0,
                 "progress": 0,
-        })
+            }
+        )
 
 
 def create_document_routes(
@@ -433,7 +433,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.white(" └─ Document Status Storage: ", end="")
     ASCIIColors.yellow(f"{args.doc_status_storage}")
 
-
     # Server Status
     ASCIIColors.green("\n✨ Server starting up...\n")
 
@@ -10,7 +10,12 @@ import pipmaster as pm
 
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
+from .shared_storage import (
+    get_namespace_data,
+    get_storage_lock,
+    get_namespace_object,
+    is_multiprocess,
+)
 
 if not pm.is_installed("faiss"):
     pm.install("faiss")
@@ -47,8 +52,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._dim = self.embedding_func.embedding_dim
         self._storage_lock = get_storage_lock()
 
-        self._index = get_namespace_object('faiss_indices')
-        self._id_to_meta = get_namespace_data('faiss_meta')
+        self._index = get_namespace_object("faiss_indices")
+        self._id_to_meta = get_namespace_data("faiss_meta")
 
         with self._storage_lock:
             if is_multiprocess:
@@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self._id_to_meta.update({})
         self._load_faiss_index()
 
-
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         # Perform the similarity search
         with self._storage_lock:
-            distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
+            distances, indices = (
+                self._index.value if is_multiprocess else self._index
+            ).search(embedding, top_k)
 
         distances = distances[0]
         indices = indices[0]
@@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with self._storage_lock:
             relations = []
             for fid, meta in self._id_to_meta.items():
-                if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                if (
+                    meta.get("src_id") == entity_name
+                    or meta.get("tgt_id") == entity_name
+                ):
                     relations.append(fid)
 
         logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
         with self._storage_lock:
-            faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
+            faiss.write_index(
+                self._index.value if is_multiprocess else self._index,
+                self._faiss_index_file,
+            )
 
             # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
             # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
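
Note: the `self._index.value if is_multiprocess else self._index` expression that recurs in these hunks reflects how the shared-storage layer hands out objects: with several workers the object sits behind a `multiprocessing.Manager().Value` proxy and must be unwrapped via `.value`, while in single-process mode it is the plain object. A reduced sketch of the access pattern (illustrative only; the real code obtains the object through `get_namespace_object`):

    from multiprocessing import Manager

    is_multiprocess = True  # assumption for this sketch
    _manager = Manager()

    # Proxy-wrapped in multi-worker mode, plain reference otherwise.
    _index = _manager.Value("O", None) if is_multiprocess else None

    def current_index():
        # Unwrap the Manager proxy only when several processes share the object.
        return _index.value if is_multiprocess else _index
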
@@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
 
-
     async def index_done_callback(self) -> None:
         # File writes must be locked to prevent several processes writing at once and corrupting the file
         with self._storage_lock:
@@ -46,15 +46,21 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if is_multiprocess:
             if self._client.value is None:
                 self._client.value = NanoVectorDB(
-                    self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                logger.info(
+                    f"Initialized vector DB client for namespace {self.namespace}"
                 )
-                logger.info(f"Initialized vector DB client for namespace {self.namespace}")
         else:
             if self._client is None:
                 self._client = NanoVectorDB(
-                    self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                logger.info(
+                    f"Initialized vector DB client for namespace {self.namespace}"
                 )
-                logger.info(f"Initialized vector DB client for namespace {self.namespace}")
 
     def _get_client(self):
         """Get the appropriate client instance based on multiprocess mode"""
@@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
             for dp in storage["data"]
             if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
         ]
-        logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
+        logger.debug(
+            f"Found {len(relations)} relations for entity {entity_name}"
+        )
         ids_to_delete = [relation["__id__"] for relation in relations]
 
         if ids_to_delete:
@@ -78,7 +78,9 @@ class NetworkXStorage(BaseGraphStorage):
         with self._storage_lock:
             if is_multiprocess:
                 if self._graph.value is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    preloaded_graph = NetworkXStorage.load_nx_graph(
+                        self._graphml_xml_file
+                    )
                     self._graph.value = preloaded_graph or nx.Graph()
                     if preloaded_graph:
                         logger.info(
@@ -88,7 +90,9 @@ class NetworkXStorage(BaseGraphStorage):
                         logger.info("Created new empty graph")
             else:
                 if self._graph is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    preloaded_graph = NetworkXStorage.load_nx_graph(
+                        self._graphml_xml_file
+                    )
                     self._graph = preloaded_graph or nx.Graph()
                     if preloaded_graph:
                         logger.info(
@@ -252,7 +256,9 @@ class NetworkXStorage(BaseGraphStorage):
         # Handle special case for "*" label
         if node_label == "*":
             # For "*", return the entire graph including all nodes and edges
-            subgraph = graph.copy()  # Create a copy to avoid modifying the original graph
+            subgraph = (
+                graph.copy()
+            )  # Create a copy to avoid modifying the original graph
         else:
             # Find nodes with matching node id (partial match)
             nodes_to_explore = []
@@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage):
         if len(subgraph.nodes()) > max_graph_nodes:
             origin_nodes = len(subgraph.nodes())
             node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
-                :max_graph_nodes
-            ]
+            top_nodes = sorted(
+                node_degrees.items(), key=lambda x: x[1], reverse=True
+            )[:max_graph_nodes]
             top_node_ids = [node[0] for node in top_nodes]
             # Create new subgraph with only top nodes
             subgraph = subgraph.subgraph(top_node_ids)
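
Note: the hunk above only reflows the call that truncates an oversized subgraph to its best-connected nodes. The underlying step — rank nodes by degree, keep the top `max_graph_nodes`, and rebuild the subgraph — can be sketched on its own with networkx (a standalone illustration, not the project's API):

    import networkx as nx

    def trim_to_top_degree_nodes(graph: nx.Graph, max_graph_nodes: int) -> nx.Graph:
        # Keep only the highest-degree nodes when the graph exceeds the limit.
        if graph.number_of_nodes() <= max_graph_nodes:
            return graph.copy()
        node_degrees = dict(graph.degree())
        top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
            :max_graph_nodes
        ]
        top_node_ids = [node for node, _ in top_nodes]
        return graph.subgraph(top_node_ids).copy()
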
@@ -17,6 +17,7 @@ _shared_dicts: Optional[Dict[str, Any]] = {}
 _share_objects: Optional[Dict[str, Any]] = {}
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 
 
 def initialize_share_data():
     """Initialize shared data, only called if multiple processes where workers > 1"""
     global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
@@ -35,6 +36,7 @@ def initialize_share_data():
         _init_flags = _manager.dict()  # Use a shared dict to store the initialization flags
         logger.info(f"Process {os.getpid()} created shared dictionaries")
 
 
 def try_initialize_namespace(namespace: str) -> bool:
     """
     Try to initialize a namespace. Returns True if the current process acquired the right to initialize it.
@@ -44,7 +46,9 @@ def try_initialize_namespace(namespace: str) -> bool:
 
     if is_multiprocess:
         if _init_flags is None:
-            raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
+            raise RuntimeError(
+                "Shared storage not initialized. Call initialize_share_data() first."
+            )
     else:
         if _init_flags is None:
             _init_flags = {}
@@ -57,12 +61,17 @@ def try_initialize_namespace(namespace: str) -> bool:
     if namespace not in _init_flags:
         # Set the initialization flag
         _init_flags[namespace] = True
-        logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
+        logger.info(
+            f"Process {os.getpid()} ready to initialize namespace {namespace}"
+        )
         return True
 
-    logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
+    logger.info(
+        f"Process {os.getpid()} found namespace {namespace} already initialized"
+    )
     return False
 
 
 def _get_global_lock() -> LockType:
     global _global_lock, is_multiprocess, _manager
 
@@ -74,40 +83,49 @@ def _get_global_lock() -> LockType:
 
     return _global_lock
 
 
 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
     return _get_global_lock()
 
 
 def get_scan_lock() -> LockType:
     """return scan_progress lock for data consistency"""
     return get_storage_lock()
 
 
 def get_namespace_object(namespace: str) -> Any:
     """Get an object for specific namespace"""
     global _share_objects, is_multiprocess, _manager
 
     if is_multiprocess and not _manager:
-        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
+        raise RuntimeError(
+            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
+        )
 
     if namespace not in _share_objects:
         lock = _get_global_lock()
         with lock:
             if namespace not in _share_objects:
                 if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value('O', None)
+                    _share_objects[namespace] = _manager.Value("O", None)
                 else:
                     _share_objects[namespace] = None
 
     return _share_objects[namespace]
 
 
 # Removed functions that are no longer used
 
 
 def get_namespace_data(namespace: str) -> Dict[str, Any]:
     """get storage space for specific storage type(namespace)"""
     global _shared_dicts, is_multiprocess, _manager
 
     if is_multiprocess and not _manager:
-        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
+        raise RuntimeError(
+            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
+        )
 
     if namespace not in _shared_dicts:
         lock = _get_global_lock()
@@ -117,6 +135,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
 
     return _shared_dicts[namespace]
 
 
 def get_scan_progress() -> Dict[str, Any]:
     """get storage space for document scanning progress data"""
-    return get_namespace_data('scan_progress')
+    return get_namespace_data("scan_progress")
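
Note: `get_namespace_data` and `get_namespace_object` implement a small namespace registry on top of `multiprocessing.Manager`: the first caller of a namespace creates the shared container under a global lock, later callers get the same container back. A reduced sketch of that double-checked creation, ignoring the single-process fallback and the Manager-backed lock used by the real module:

    from multiprocessing import Manager
    from threading import Lock
    from typing import Any, Dict

    _manager = Manager()
    _shared_dicts: Dict[str, Any] = {}
    _global_lock = Lock()

    def get_namespace_data(namespace: str) -> Dict[str, Any]:
        # Check, take the lock, check again, then create a Manager-backed dict
        # that worker processes holding the proxy can read and write.
        if namespace not in _shared_dicts:
            with _global_lock:
                if namespace not in _shared_dicts:
                    _shared_dicts[namespace] = _manager.dict()
        return _shared_dicts[namespace]
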
@@ -55,6 +55,7 @@ def set_verbose_debug(enabled: bool):
     global VERBOSE_DEBUG
     VERBOSE_DEBUG = enabled
 
 
 statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
 
 # Initialize logger
@@ -100,6 +101,7 @@ class UnlimitedSemaphore:
 
 ENCODER = None
 
 
 @dataclass
 class EmbeddingFunc:
     embedding_dim: int