Merge remote-tracking branch 'origin/main'

# Conflicts:
#	lightrag/llm.py
#	lightrag/operate.py
yuanxiaobin
2024-12-06 15:06:00 +08:00
6 changed files with 198 additions and 6 deletions

View File

@@ -0,0 +1,140 @@
from datetime import datetime, timezone

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional
import os
import logging

from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

# Index the source document before serving queries
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")

# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes
@app.get("/")
async def index():
    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
    return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
                "size": 4683087332,
                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "qwen2",
                    "families": ["qwen2"],
                    "parameter_size": "7.6B",
                    "quantization_level": "Q4_K_M",
                },
            }
        ]
    }


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
    resp = rag.query(
        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
    )
    if inspect.isasyncgen(resp):

        async def ollama_resp(chunks):
            async for chunk in chunks:
                yield (
                    json.dumps(
                        {
                            "model": MODEL_NAME,
                            "created_at": datetime.now(timezone.utc).strftime(
                                "%Y-%m-%dT%H:%M:%S.%fZ"
                            ),
                            "message": {
                                "role": "assistant",
                                "content": chunk,
                            },
                            "done": False,
                        },
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"  # the trailing b"\n" is important: clients split the stream on newlines
                )

        return StreamingResponse(ollama_resp(resp), media_type="application/json")
    else:
        return resp


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)
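
For reference (not part of this commit), a minimal client-side sketch of how the streaming /ollama/api/chat endpoint above can be consumed. It assumes the server is running on localhost:8020 as configured and that the httpx package is available; each streamed line is one JSON object terminated by "\n", exactly as produced by ollama_resp() above.

import json

import httpx

# Hypothetical client payload; the endpoint only reads the last message's content.
payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}

# Stream the newline-delimited JSON response and print the assistant content as it arrives.
with httpx.stream(
    "POST", "http://localhost:8020/ollama/api/chat", json=payload, timeout=None
) as response:
    for line in response.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["message"]["content"], end="", flush=True)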

View File

@@ -1,4 +1,6 @@
import asyncio
import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

# stream response
resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        print(chunk, end="", flush=True)


if inspect.isasyncgen(resp):
    asyncio.run(print_stream(resp))
else:
    print(resp)

View File

@@ -19,6 +19,7 @@ class QueryParam:
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
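
As a usage note (not part of this diff), a minimal sketch of how the new stream flag reaches callers through rag.query(); rag is assumed to be an already-configured LightRAG instance as in the examples above. Because a cache hit in ollama_model_if_cache still returns a plain string, callers branch on the result type.

import asyncio
import inspect

from lightrag import QueryParam

# Assumption: "rag" is a configured LightRAG instance (see the example scripts above).
resp = rag.query("Summarize the story.", param=QueryParam(mode="hybrid", stream=True))


async def consume(stream):
    async for chunk in stream:
        print(chunk, end="", flush=True)


# With stream=True the result may be an async generator; cache hits and non-streaming
# model functions still return a plain string, so both cases are handled.
if inspect.isasyncgen(resp):
    asyncio.run(consume(resp))
else:
    print(resp)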

View File

@@ -143,7 +143,7 @@ class OracleDB:
            data = None
        return data

    async def execute(self, sql: str, data: list | dict = None):
    async def execute(self, sql: str, data: Union[list, dict] = None):
        # logger.info("go into OracleDB execute method")
        try:
            async with self.pool.acquire() as connection:

View File

@@ -4,8 +4,7 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Optional
from dataclasses import dataclass
from typing import List, Dict, Callable, Any, Union
import aioboto3
import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
    get_best_cached_response,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> str:
) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False

    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
        return cached_response

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """cannot cache stream response"""

        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

        return inner()
    else:
        result = response["message"]["content"]
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=result,
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return result
    result = response["message"]["content"]
    # Save to cache
@@ -697,7 +728,7 @@ async def hf_model_complete(
async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"

View File

@@ -536,9 +536,10 @@ async def kg_query(
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
        mode=query_param.mode,
    )
    if len(response) > len(sys_prompt):
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")