Update sample code for OpenAI and OpenAI compatible

This commit is contained in:
yangdx
2025-04-21 00:09:05 +08:00
parent 1a7b225e90
commit e0f0d23e5a
3 changed files with 148 additions and 102 deletions

View File

@@ -1,13 +1,83 @@
import os
import asyncio
import inspect
import logging
import logging.config
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
WORKING_DIR = "./dickens"
def configure_logging():
"""Configure logging for the application"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger_instance = logging.getLogger(logger_name)
logger_instance.handlers = []
logger_instance.filters = []
# Get log directory path from environment variable or use current directory
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(
os.path.join(log_dir, "lightrag_compatible_demo.log")
)
print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
os.makedirs(os.path.dirname(log_dir), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
},
}
)
# Set the logger level to INFO
logger.setLevel(logging.INFO)
# Enable verbose debug if needed
set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -16,22 +86,21 @@ async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
"deepseek-chat",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
api_key=os.getenv("OPENAI_API_KEY"),
base_url="https://api.deepseek.com",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
return await ollama_embed(
texts=texts,
embed_model="bge-m3:latest",
host="http://m4.lan.znipower.com:11434",
)
@@ -54,6 +123,12 @@ async def test_funcs():
# asyncio.run(test_funcs())
async def print_stream(stream):
async for chunk in stream:
if chunk:
print(chunk, end="", flush=True)
async def initialize_rag():
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
@@ -83,37 +158,66 @@ async def main():
await rag.ainsert(f.read())
# Perform naive search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
print("\n=====================")
print("Query mode: naive")
print("=====================")
resp = await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="naive", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
# Perform local search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
print("\n=====================")
print("Query mode: local")
print("=====================")
resp = await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="local", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
# Perform global search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
print("\n=====================")
print("Query mode: global")
print("=====================")
resp = await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="global", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
# Perform hybrid search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
print("\n=====================")
print("Query mode: hybrid")
print("=====================")
resp = await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
except Exception as e:
print(f"An error occurred: {e}")
finally:
if rag:
await rag.finalize_storages()
if __name__ == "__main__":
# Configure logging before running the main function
configure_logging()
asyncio.run(main())
print("\nDone!")

View File

@@ -1,72 +0,0 @@
import inspect
import os
import asyncio
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embed
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
from lightrag import QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
# WorkingDir
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
print(f"WorkingDir: {WORKING_DIR}")
api_key = "empty"
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=openai_complete,
llm_model_name="qwen2.5-14b-instruct@4bit",
llm_model_max_async=4,
llm_model_max_token_size=32768,
llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=lambda texts: openai_embed(
texts=texts,
model="text-embedding-bge-m3",
base_url="http://127.0.0.1:1234/v1",
api_key=api_key,
),
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def print_stream(stream):
async for chunk in stream:
if chunk:
print(chunk, end="", flush=True)
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
resp = rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid", stream=True),
)
loop = always_get_an_event_loop()
if inspect.isasyncgen(resp):
loop.run_until_complete(print_stream(resp))
else:
print(resp)
if __name__ == "__main__":
main()

View File

@@ -9,9 +9,10 @@ from lightrag.utils import logger, set_verbose_debug
WORKING_DIR = "./dickens"
def configure_logging():
"""Configure logging for the application"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger_instance = logging.getLogger(logger_name)
@@ -65,12 +66,13 @@ def configure_logging():
},
}
)
# Set the logger level to INFO
logger.setLevel(logging.INFO)
# Enable verbose debug if needed
set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -97,6 +99,9 @@ async def main():
await rag.ainsert(f.read())
# Perform naive search
print("\n=====================")
print("Query mode: naive")
print("=====================")
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="naive")
@@ -104,6 +109,9 @@ async def main():
)
# Perform local search
print("\n=====================")
print("Query mode: local")
print("=====================")
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="local")
@@ -111,6 +119,9 @@ async def main():
)
# Perform global search
print("\n=====================")
print("Query mode: global")
print("=====================")
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="global")
@@ -118,6 +129,9 @@ async def main():
)
# Perform hybrid search
print("\n=====================")
print("Query mode: hybrid")
print("=====================")
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")