Optimization logic

This commit is contained in:
jin
2024-11-25 13:29:55 +08:00
parent 662303f605
commit 89c2de54a2
10 changed files with 342 additions and 423 deletions

View File

@@ -1,11 +1,16 @@
 from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi import Query
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Any
 from fastapi.responses import JSONResponse
-import sys
-import os
+import sys, os
+print(os.getcwd())
+from pathlib import Path
+script_directory = Path(__file__).resolve().parent.parent
+sys.path.append(os.path.abspath(script_directory))
 import asyncio
 import nest_asyncio
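Note on the added path lines: script_directory resolves to the repository root (the parent of this script's directory), and appending it to sys.path is what lets the `from lightrag import ...` statements below resolve when running the example from a source checkout rather than an installed package.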
@@ -13,15 +18,11 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
-print(os.getcwd())
-script_directory = Path(__file__).resolve().parent.parent
-sys.path.append(os.path.abspath(script_directory))

 # Apply nest_asyncio to solve event loop issues
 nest_asyncio.apply()
@@ -37,18 +38,16 @@ APIKEY = "ocigenerativeai"
 # Configure working directory
 WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
 print(f"WORKING_DIR: {WORKING_DIR}")
-LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus")
+LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
 print(f"LLM_MODEL: {LLM_MODEL}")
 EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)


 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
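The body of llm_model_func falls outside this hunk. For orientation, a minimal sketch of what such a callback typically looks like in this example, assuming it simply forwards to the openai_complete_if_cache helper imported above (BASE_URL is an assumed module-level setting; APIKEY is defined in the context shown above):

async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Delegate to the cached OpenAI-compatible completion helper.
    return await openai_complete_if_cache(
        LLM_MODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,  # assumed; whatever endpoint this file configures
        **kwargs,
    )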
@@ -78,10 +77,10 @@ async def get_embedding_dim():
     embedding_dim = embedding.shape[1]
     return embedding_dim


 async def init():
     # Detect embedding dimension
-    embedding_dimension = await get_embedding_dim()
+    embedding_dimension = 1024 #await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")

     # Create Oracle DB connection
     # The `config` parameter is the connection configuration of Oracle DB
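Hardcoding embedding_dimension = 1024 skips one embedding API round-trip at startup; 1024 matches the output dimension of cohere.embed-multilingual-v3.0, so the shortcut should agree with what get_embedding_dim() would detect, though it silently breaks if EMBEDDING_MODEL is changed.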
@@ -89,36 +88,36 @@ async def init():
     # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "",
-            "password": "",
-            "dsn": "",
-            "config_dir": "",
-            "wallet_location": "",
-            "wallet_password": "",
-            "workspace": "",
-        } # specify which docs you want to store and query
-    )
+    oracle_db = OracleDB(config={
+        "user":"",
+        "password":"",
+        "dsn":"",
+        "config_dir":"path_to_config_dir",
+        "wallet_location":"path_to_wallet_location",
+        "wallet_password":"wallet_password",
+        "workspace":"company"
+        } # specify which docs you want to store and query
+    )

     # Check if Oracle DB tables exist, if not, tables will be created
     await oracle_db.check_tables()

     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
-        enable_llm_cache=False,
-        working_dir=WORKING_DIR,
-        chunk_token_size=512,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=embedding_dimension,
-            max_token_size=512,
-            func=embedding_func,
-        ),
-        graph_storage="OracleGraphStorage",
-        kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage",
-    )
+        enable_llm_cache=False,
+        working_dir=WORKING_DIR,
+        chunk_token_size=512,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=embedding_dimension,
+            max_token_size=512,
+            func=embedding_func,
+        ),
+        graph_storage = "OracleGraphStorage",
+        kv_storage="OracleKVStorage",
+        vector_storage="OracleVectorDBStorage"
+    )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
     rag.graph_storage_cls.db = oracle_db
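The hunk cuts off after the first assignment; presumably the KV and vector storage classes are wired to the same OracleDB instance so all three backends share one connection pool. A sketch, where the two remaining attribute names are assumptions inferred from the storage names passed to LightRAG above:

rag.graph_storage_cls.db = oracle_db                  # OracleGraphStorage (shown above)
rag.key_string_value_json_storage_cls.db = oracle_db  # OracleKVStorage (assumed attribute)
rag.vector_db_storage_cls.db = oracle_db              # OracleVectorDBStorage (assumed attribute)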
@@ -128,6 +127,17 @@ async def init():
     return rag


 # Extract and Insert into LightRAG storage
+# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
+#     await rag.ainsert(f.read())
+
+# # Perform search in different modes
+# modes = ["naive", "local", "global", "hybrid"]
+# for mode in modes:
+#     print("="*20, mode, "="*20)
+#     print(await rag.aquery("What is this document about?", param=QueryParam(mode=mode)))
+#     print("-"*100, "\n")

 # Data models
@@ -135,7 +145,10 @@ class QueryRequest(BaseModel):
     query: str
     mode: str = "hybrid"
     only_need_context: bool = False
+    only_need_prompt: bool = False


+class DataRequest(BaseModel):
+    limit: int = 100


 class InsertRequest(BaseModel):
     text: str
@@ -143,7 +156,7 @@ class InsertRequest(BaseModel):
 class Response(BaseModel):
     status: str
-    data: Optional[str] = None
+    data: Optional[Any] = None
     message: Optional[str] = None
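Widening data from Optional[str] to Optional[Any] is what allows the /data endpoint added below to return lists of nodes/edges and a statistics dict without tripping FastAPI's response validation.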
@@ -151,7 +164,6 @@ class Response(BaseModel):
 rag = None  # defined as a global object


 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
@@ -160,24 +172,39 @@ async def lifespan(app: FastAPI):
     yield


-app = FastAPI(
-    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
-)
+app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)


 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
-    try:
+    #try:
+    #    loop = asyncio.get_event_loop()
-        result = await rag.aquery(
+    if request.mode == "naive":
+        top_k = 3
+    else:
+        top_k = 60
+    result = await rag.aquery(
         request.query,
         param=QueryParam(
-            mode=request.mode, only_need_context=request.only_need_context
+            mode=request.mode,
+            only_need_context=request.only_need_context,
+            only_need_prompt=request.only_need_prompt,
+            top_k=top_k
         ),
     )
-        return Response(status="success", data=result)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+    return Response(status="success", data=result)
+    # except Exception as e:
+    #     raise HTTPException(status_code=500, detail=str(e))


+@app.get("/data", response_model=Response)
+async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
+    if type == "nodes":
+        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
+    elif type == "edges":
+        result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
+    elif type == "statistics":
+        result = await rag.chunk_entity_relation_graph.get_statistics()
+    return Response(status="success", data=result)


 @app.post("/insert", response_model=Response)
@@ -220,7 +247,7 @@ async def health_check():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
uvicorn.run(app, host="127.0.0.1", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
@@ -237,4 +264,4 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -97,6 +97,8 @@ async def main():
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params = {"example_number":1, "language":"Simplfied Chinese"},
)
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
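For reference: addon_params feeds prompt construction; "language" asks the model to produce its output in Simplified Chinese, and "example_number" presumably caps how many few-shot examples the extraction prompts include.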