Optimize Ollama LLM driver

yangdx
2025-05-14 01:13:03 +08:00
parent aa36894d6e
commit b836d02cac
3 changed files with 75 additions and 41 deletions

View File

@@ -415,7 +415,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts,
             embed_model="nomic-embed-text"
         )
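Note: for readers migrating, a minimal runnable sketch of the renamed configuration. The module paths (`lightrag.llm.ollama`, `lightrag.utils.EmbeddingFunc`) match the repository layout touched by this commit; the `working_dir` and model names are illustrative placeholders, not part of the diff.

import asyncio
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed, ollama_model_complete

rag = LightRAG(
    working_dir="./rag_storage",  # hypothetical path
    llm_model_func=ollama_model_complete,
    llm_model_name="llama3",  # hypothetical model name
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        # ollama_embed replaces the removed ollama_embedding helper
        func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
    ),
)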

View File

@@ -447,7 +447,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts,
             embed_model="nomic-embed-text"
         )

View File

@@ -31,6 +31,7 @@ from lightrag.api import __api_version__
 import numpy as np
 from typing import Union
+from lightrag.utils import logger

 @retry(
@@ -52,7 +53,7 @@ async def _ollama_model_if_cache(
kwargs.pop("max_tokens", None) kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json # kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None) host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None) timeout = kwargs.pop("timeout", None) or 300 # Default timeout 300s
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None) api_key = kwargs.pop("api_key", None)
headers = { headers = {
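Note: the new default uses `or`, so any falsy caller value falls through to 300. A small sketch of that semantics (the helper name is hypothetical, not in the commit):

def effective_timeout(**kwargs) -> int:
    # Mirrors the commit's fallback: falsy values take the default
    return kwargs.pop("timeout", None) or 300

assert effective_timeout() == 300
assert effective_timeout(timeout=None) == 300
assert effective_timeout(timeout=0) == 300   # 0 cannot disable the timeout
assert effective_timeout(timeout=600) == 600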
@@ -61,32 +62,59 @@ async def _ollama_model_if_cache(
     }
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"

     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)

-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
-    messages.append({"role": "user", "content": prompt})
-
-    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-    if stream:
-        """cannot cache stream response and process reasoning"""
-
-        async def inner():
-            async for chunk in response:
-                yield chunk["message"]["content"]
-
-        return inner()
-    else:
-        model_response = response["message"]["content"]
-
-        """
-        If the model also wraps its thoughts in a specific tag,
-        this information is not needed for the final
-        response and can simply be trimmed.
-        """
-
-        return model_response
+    try:
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.extend(history_messages)
+        messages.append({"role": "user", "content": prompt})
+
+        response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+        if stream:
+            """cannot cache stream response and process reasoning"""
+
+            async def inner():
+                try:
+                    async for chunk in response:
+                        yield chunk["message"]["content"]
+                except Exception as e:
+                    logger.error(f"Error in stream response: {str(e)}")
+                    raise
+                finally:
+                    try:
+                        await ollama_client._client.aclose()
+                        logger.debug("Successfully closed Ollama client for streaming")
+                    except Exception as close_error:
+                        logger.warning(f"Failed to close Ollama client: {close_error}")
+
+            return inner()
+        else:
+            model_response = response["message"]["content"]
+
+            """
+            If the model also wraps its thoughts in a specific tag,
+            this information is not needed for the final
+            response and can simply be trimmed.
+            """
+
+            return model_response
+    except Exception as e:
+        try:
+            await ollama_client._client.aclose()
+            logger.debug("Successfully closed Ollama client after exception")
+        except Exception as close_error:
+            logger.warning(f"Failed to close Ollama client after exception: {close_error}")
+        raise e
+    finally:
+        if not stream:
+            try:
+                await ollama_client._client.aclose()
+                logger.debug("Successfully closed Ollama client for non-streaming response")
+            except Exception as close_error:
+                logger.warning(f"Failed to close Ollama client in finally block: {close_error}")

 async def ollama_model_complete(
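Note the shape of this hunk: non-streaming calls close the client in the outer `finally`, while streaming calls defer cleanup into `inner()`'s `finally`, because the HTTP connection must stay open until the consumer drains the generator. The code also reaches into `ollama.AsyncClient`'s private `_client` attribute (the underlying `httpx.AsyncClient`), presumably because the ollama package exposed no public close method at the time. A generic sketch of this ownership pattern, using the real httpx API but a hypothetical function:

import httpx

async def stream_lines(url: str):
    # The generator owns the client: the finally block runs when the
    # consumer finishes, raises, or abandons the iterator (GeneratorExit).
    client = httpx.AsyncClient()
    try:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                yield line
    finally:
        await client.aclose()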
@@ -105,19 +133,6 @@ async def ollama_model_complete(
     )

-async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
-    """
-    Deprecated in favor of `embed`.
-    """
-    embed_text = []
-    ollama_client = ollama.Client(**kwargs)
-    for text in texts:
-        data = ollama_client.embeddings(model=embed_model, prompt=text)
-        embed_text.append(data["embedding"])
-    return embed_text

 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
     headers = {
@@ -125,8 +140,27 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"User-Agent": f"LightRAG/{__api_version__}", "User-Agent": f"LightRAG/{__api_version__}",
} }
if api_key: if api_key:
headers["Authorization"] = api_key headers["Authorization"] = f"Bearer {api_key}"
kwargs["headers"] = headers
ollama_client = ollama.Client(**kwargs) host = kwargs.pop("host", None)
data = ollama_client.embed(model=embed_model, input=texts) timeout = kwargs.pop("timeout", None) or 90 # Default time out 90s
return np.array(data["embeddings"])
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
try:
data = await ollama_client.embed(model=embed_model, input=texts)
return np.array(data["embeddings"])
except Exception as e:
logger.error(f"Error in ollama_embed: {str(e)}")
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client after exception in embed")
except Exception as close_error:
logger.warning(f"Failed to close Ollama client after exception in embed: {close_error}")
raise e
finally:
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client after embed")
except Exception as close_error:
logger.warning(f"Failed to close Ollama client after embed: {close_error}")