Logic Optimization

This commit is contained in:
jin
2024-11-25 13:40:38 +08:00
parent bf5815be8f
commit 21f161390a
8 changed files with 185 additions and 136 deletions

View File

@@ -1,16 +1,14 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional,Any
from fastapi.responses import JSONResponse
from typing import Optional, Any
import sys
import os
import sys, os
print(os.getcwd())
from pathlib import Path
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
import asyncio
import nest_asyncio
@@ -18,10 +16,12 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
# Apply nest_asyncio to solve event loop issues
@@ -47,7 +47,8 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -77,10 +78,10 @@ async def get_embedding_dim():
embedding_dim = embedding.shape[1]
return embedding_dim
async def init():
# Detect embedding dimension
embedding_dimension = 1024 #await get_embedding_dim()
embedding_dimension = 1024 # await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
@@ -88,36 +89,36 @@ async def init():
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "path_to_config_dir",
"wallet_location": "path_to_wallet_location",
"wallet_password": "wallet_password",
"workspace": "company",
} # specify which docs you want to store and query
)
oracle_db = OracleDB(config={
"user":"",
"password":"",
"dsn":"",
"config_dir":"path_to_config_dir",
"wallet_location":"path_to_wallet_location",
"wallet_password":"wallet_password",
"workspace":"company"
} # specify which docs you want to store and query
)
# Check whether the Oracle DB tables exist; if not, they will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
)
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.graph_storage_cls.db = oracle_db
@@ -128,7 +129,7 @@ async def init():
# Extract and Insert into LightRAG storage
#with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# await rag.ainsert(f.read())
# # Perform search in different modes
@@ -147,9 +148,11 @@ class QueryRequest(BaseModel):
only_need_context: bool = False
only_need_prompt: bool = False
class DataRequest(BaseModel):
limit: int = 100
class InsertRequest(BaseModel):
text: str
@@ -164,6 +167,7 @@ class Response(BaseModel):
rag = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
@@ -172,25 +176,28 @@ async def lifespan(app: FastAPI):
yield
app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
#try:
# loop = asyncio.get_event_loop()
# try:
# loop = asyncio.get_event_loop()
if request.mode == "naive":
top_k = 3
else:
top_k = 60
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k
),
)
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k,
),
)
return Response(status="success", data=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
@@ -199,9 +206,9 @@ async def query_endpoint(request: QueryRequest):
@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
if type == "nodes":
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
elif type == "edges":
result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
elif type == "statistics":
result = await rag.chunk_entity_relation_graph.get_statistics()
return Response(status="success", data=result)
@@ -264,4 +271,4 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -97,8 +97,7 @@ async def main():
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params = {"example_number":1, "language":"Simplfied Chinese"},
addon_params={"example_number": 1, "language": "Simplfied Chinese"},
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool