feat: 新增ini文件读取数据库配置方式,方便生产环境,修改Lightrag ainsert方法_add_doc_keys获取方式,原来只过滤存在的,但这会让失败的文档无法再次存储,新增--chunk_size和--chunk_overlap_size方便生产环境,新增llm_binding:openai-ollama 方便用openai的同时使用ollama embedding

This commit is contained in:
hyb
2025-01-23 22:58:57 +08:00
parent 3c5ced835e
commit ff71952c8c
5 changed files with 111 additions and 7 deletions

13
config.ini Normal file
View File

@@ -0,0 +1,13 @@
[redis]
uri = redis://localhost:6379
[neo4j]
uri = #
username = neo4j
password = 12345678
[milvus]
uri = #
user = root
password = Milvus
db_name = lightrag

View File

@@ -20,6 +20,7 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import os
import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
@@ -57,6 +58,52 @@ LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"

# Optional database settings are read from config.ini. A missing file or a
# missing section leaves the JSON/NetworkX/NanoVectorDB defaults above intact;
# a configured section switches the corresponding storage backend and exports
# its credentials through environment variables for the storage classes.
config = configparser.ConfigParser()
config.read("config.ini")

# Redis: replaces the KV store and the document-status store.
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
    os.environ["REDIS_URI"] = redis_uri
    KV_STORAGE = "RedisKVStorage"
    DOC_STATUS_STORAGE = "RedisKVStorage"

# Neo4j: replaces the graph store.
neo4j_uri = config.get("neo4j", "uri", fallback=None)
neo4j_username = config.get("neo4j", "username", fallback=None)
neo4j_password = config.get("neo4j", "password", fallback=None)
if neo4j_uri:
    os.environ["NEO4J_URI"] = neo4j_uri
    # Guard optional values: `os.environ[...] = None` raises TypeError when a
    # section is only partially filled in.
    if neo4j_username:
        os.environ["NEO4J_USERNAME"] = neo4j_username
    if neo4j_password:
        os.environ["NEO4J_PASSWORD"] = neo4j_password
    GRAPH_STORAGE = "Neo4JStorage"

# Milvus: replaces the vector store.
# NOTE(review): "MilvusVectorDBStorge" matches the (misspelled) class name used
# elsewhere in this project — do not correct the spelling here without also
# renaming the storage class.
milvus_uri = config.get("milvus", "uri", fallback=None)
milvus_user = config.get("milvus", "user", fallback=None)
milvus_password = config.get("milvus", "password", fallback=None)
milvus_db_name = config.get("milvus", "db_name", fallback=None)
if milvus_uri:
    os.environ["MILVUS_URI"] = milvus_uri
    if milvus_user:
        os.environ["MILVUS_USER"] = milvus_user
    if milvus_password:
        os.environ["MILVUS_PASSWORD"] = milvus_password
    if milvus_db_name:
        os.environ["MILVUS_DB_NAME"] = milvus_db_name
    VECTOR_STORAGE = "MilvusVectorDBStorge"

# MongoDB: replaces the KV store and the document-status store.
mongo_uri = config.get("mongodb", "uri", fallback=None)
# Bug fix: the original read config.get("mongodb", "LightRAG", fallback=None),
# i.e. it looked up an option literally named "LightRAG" and could later assign
# None into os.environ (a TypeError). The option name is "database" and
# "LightRAG" is the intended default database name.
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
if mongo_uri:
    os.environ["MONGO_URI"] = mongo_uri
    os.environ["MONGO_DATABASE"] = mongo_database
    KV_STORAGE = "MongoKVStorage"
    DOC_STATUS_STORAGE = "MongoKVStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
@@ -337,6 +384,18 @@ def parse_args() -> argparse.Namespace:
help="Embedding model name (default: from env or bge-m3:latest)",
)
parser.add_argument(
"--chunk_size",
default=1200,
help="chunk token size default 1200",
)
parser.add_argument(
"--chunk_overlap_size",
default=100,
help="chunk overlap token size default 100",
)
def timeout_type(value):
if value is None or value == "None":
return None
@@ -551,7 +610,7 @@ def get_api_key_dependency(api_key: Optional[str]):
def create_app(args):
# Verify that bindings are correctly set up
if args.llm_binding not in ["lollms", "ollama", "openai"]:
if args.llm_binding not in ["lollms", "ollama", "openai", "openai-ollama"]:
raise Exception("llm binding not supported")
if args.embedding_binding not in ["lollms", "ollama", "openai"]:
@@ -692,22 +751,32 @@ def create_app(args):
)
# Initialize RAG
if args.llm_binding in ["lollms", "ollama"]:
if args.llm_binding in ["lollms", "ollama", "openai-ollama"]:
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete,
else ollama_model_complete
if args.llm_binding == "ollama"
else openai_alike_model_complete,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
"api_key": args.llm_binding_api_key,
},
}
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {},
embedding_func=embedding_func,
kv_storage=KV_STORAGE,
graph_storage=GRAPH_STORAGE,
vector_storage=VECTOR_STORAGE,
doc_status_storage=DOC_STATUS_STORAGE,
)
else:
rag = LightRAG(
@@ -715,7 +784,13 @@ def create_app(args):
llm_model_func=azure_openai_model_complete
if args.llm_binding == "azure_openai"
else openai_alike_model_complete,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
embedding_func=embedding_func,
kv_storage=KV_STORAGE,
graph_storage=GRAPH_STORAGE,
vector_storage=VECTOR_STORAGE,
doc_status_storage=DOC_STATUS_STORAGE,
)
async def index_file(file_path: Union[str, Path]) -> None:

View File

@@ -361,7 +361,13 @@ class LightRAG:
}
# 3. Filter out already processed documents
_add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
# _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
_add_doc_keys = {
doc_id
for doc_id in new_docs.keys()
if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
or current_doc["status"] == DocStatus.FAILED
}
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not new_docs:
@@ -573,7 +579,7 @@ class LightRAG:
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(
f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents"
)
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
@@ -618,7 +624,7 @@ class LightRAG:
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(),
desc=f"Level 1 - Splitting doc in batch {i//batch_size + 1}",
desc=f"Level 1 - Splitting doc in batch {i // batch_size + 1}",
):
try:
# Generate chunks from document

View File

@@ -445,6 +445,9 @@ class JsonDocStatusStorage(DocStatusStorage):
await self.index_done_callback()
return data
async def get_by_id(self, id: str):
    """Return the raw stored status record for *id*, or None when absent."""
    try:
        return self._data[id]
    except KeyError:
        return None
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
    """Get document status by ID.

    Returns None when no status has been recorded for *doc_id*.
    """
    if doc_id in self._data:
        return self._data[doc_id]
    return None

View File

@@ -4,6 +4,7 @@ aiofiles
aiohttp
aioredis
asyncpg
# NOTE(review): "configparser" is part of the Python 3 standard library; the
# PyPI package of that name is a Python 2 backport and must not be installed.
# database packages
graspologic
@@ -17,13 +18,19 @@ numpy
ollama
openai
oracledb
pipmaster
psycopg-pool
psycopg[binary,pool]
pydantic
pymilvus
pymongo
pymysql
PyPDF2
python-docx
python-dotenv
python-pptx
pyvis
setuptools
# lmdeploy[all]