Merge remote-tracking branch 'origin/main'

# Conflicts:
#	lightrag/llm.py
#	lightrag/operate.py
yuanxiaobin
2024-12-06 15:06:00 +08:00
6 changed files with 198 additions and 6 deletions

View File

@@ -0,0 +1,140 @@
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional
import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")

# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes
@app.get("/")
async def index():
    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
    return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
                "size": 4683087332,
                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "qwen2",
                    "families": ["qwen2"],
                    "parameter_size": "7.6B",
                    "quantization_level": "Q4_K_M",
                },
            }
        ]
    }


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
    resp = rag.query(
        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
    )
    if inspect.isasyncgen(resp):

        async def ollama_resp(chunks):
            async for chunk in chunks:
                yield (
                    json.dumps(
                        {
                            "model": MODEL_NAME,
                            "created_at": datetime.now(timezone.utc).strftime(
                                "%Y-%m-%dT%H:%M:%S.%fZ"
                            ),
                            "message": {
                                "role": "assistant",
                                "content": chunk,
                            },
                            "done": False,
                        },
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"
                )  # the b"\n" is important

        return StreamingResponse(ollama_resp(resp), media_type="application/json")
    else:
        return resp


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)
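
For reference, a minimal client sketch (not part of the commit) for the /ollama/api/chat endpoint defined above. It assumes the server is running locally on port 8020 as configured and that the requests package is available; each streamed line is a standalone JSON object, which is why the handler appends b"\n" to every chunk.

import json
import requests

payload = {
    "model": "LightRAG:latest",
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
    "stream": True,
}

# Hypothetical consumer: read the newline-delimited JSON stream chunk by chunk.
with requests.post(
    "http://localhost:8020/ollama/api/chat", json=payload, stream=True
) as r:
    for line in r.iter_lines():  # splits on the b"\n" delimiters added by the server
        if line:
            chunk = json.loads(line)
            print(chunk["message"]["content"], end="", flush=True)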

View File

@@ -1,4 +1,6 @@
import asyncio
import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

# stream response
resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        print(chunk, end="", flush=True)


if inspect.isasyncgen(resp):
    asyncio.run(print_stream(resp))
else:
    print(resp)

View File

@@ -19,6 +19,7 @@ class QueryParam:
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.

View File

@@ -143,7 +143,7 @@ class OracleDB:
            data = None
        return data

-    async def execute(self, sql: str, data: list | dict = None):
+    async def execute(self, sql: str, data: Union[list, dict] = None):
        # logger.info("go into OracleDB execute method")
        try:
            async with self.pool.acquire() as connection:
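
A short note on this one-line change: the "list | dict" union spelling (PEP 604) is evaluated when the method is defined and only works on Python 3.10+, so switching to typing.Union keeps the Oracle backend importable on 3.8/3.9. An illustrative comparison (assumed, not taken from the repository):

from typing import Union

# Importable on Python 3.8/3.9 as well as 3.10+:
def execute(sql: str, data: Union[list, dict] = None) -> None:
    ...

# Equivalent PEP 604 spelling, but evaluating "list | dict" in a signature
# raises TypeError at import time on interpreters older than 3.10:
# def execute(sql: str, data: list | dict = None) -> None:
#     ...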

View File

@@ -4,8 +4,7 @@ import json
import os
import struct
from functools import lru_cache
-from typing import List, Dict, Callable, Any, Optional
+from typing import List, Dict, Callable, Any, Union
-from dataclasses import dataclass
import aioboto3
import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
    get_best_cached_response,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
    system_prompt=None,
    history_messages=[],
    **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None)  # allow json
    host = kwargs.pop("host", None)
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
        return cached_response

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """ cannot cache stream response """

        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

        return inner()
    else:
        result = response["message"]["content"]
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=result,
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return result

    result = response["message"]["content"]
    # Save to cache
@@ -697,7 +728,7 @@ async def hf_model_complete(
async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"

View File

@@ -536,9 +536,10 @@ async def kg_query(
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
        mode=query_param.mode,
    )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")