diff --git a/README-zh.md b/README-zh.md
index 66690ee8..5300b2cf 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -415,7 +415,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts, embed_model="nomic-embed-text"
         )
     ),
diff --git a/README.md b/README.md
index 449880f2..12e18f0d 100644
--- a/README.md
+++ b/README.md
@@ -447,7 +447,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts, embed_model="nomic-embed-text"
         )
     ),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 21ae9a67..7668be44 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -31,6 +31,7 @@ from lightrag.api import __api_version__

 import numpy as np
 from typing import Union
+from lightrag.utils import logger


 @retry(
@@ -52,7 +53,7 @@ async def _ollama_model_if_cache(
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
-    timeout = kwargs.pop("timeout", None)
+    timeout = kwargs.pop("timeout", None) or 300  # Default timeout 300s
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
     headers = {
@@ -61,32 +62,59 @@ async def _ollama_model_if_cache(
     }
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
+
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
-    messages.append({"role": "user", "content": prompt})
+
+    try:
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.extend(history_messages)
+        messages.append({"role": "user", "content": prompt})

-    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-    if stream:
-        """cannot cache stream response and process reasoning"""
+        response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+        if stream:
+            """cannot cache stream response and process reasoning"""

-        async def inner():
-            async for chunk in response:
-                yield chunk["message"]["content"]
+            async def inner():
+                try:
+                    async for chunk in response:
+                        yield chunk["message"]["content"]
+                except Exception as e:
+                    logger.error(f"Error in stream response: {str(e)}")
+                    raise
+                finally:
+                    try:
+                        await ollama_client._client.aclose()
+                        logger.debug("Successfully closed Ollama client for streaming")
+                    except Exception as close_error:
+                        logger.warning(f"Failed to close Ollama client: {close_error}")

-        return inner()
-    else:
-        model_response = response["message"]["content"]
+            return inner()
+        else:
+            model_response = response["message"]["content"]

-        """
-        If the model also wraps its thoughts in a specific tag,
-        this information is not needed for the final
-        response and can simply be trimmed.
-        """
+            """
+            If the model also wraps its thoughts in a specific tag,
+            this information is not needed for the final
+            response and can simply be trimmed.
+ """ - return model_response + return model_response + except Exception as e: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after exception") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after exception: {close_error}") + raise e + finally: + if not stream: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client for non-streaming response") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client in finally block: {close_error}") async def ollama_model_complete( @@ -105,19 +133,6 @@ async def ollama_model_complete( ) -async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: - """ - Deprecated in favor of `embed`. - """ - embed_text = [] - ollama_client = ollama.Client(**kwargs) - for text in texts: - data = ollama_client.embeddings(model=embed_model, prompt=text) - embed_text.append(data["embedding"]) - - return embed_text - - async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) headers = { @@ -125,8 +140,27 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: "User-Agent": f"LightRAG/{__api_version__}", } if api_key: - headers["Authorization"] = api_key - kwargs["headers"] = headers - ollama_client = ollama.Client(**kwargs) - data = ollama_client.embed(model=embed_model, input=texts) - return np.array(data["embeddings"]) + headers["Authorization"] = f"Bearer {api_key}" + + host = kwargs.pop("host", None) + timeout = kwargs.pop("timeout", None) or 90 # Default time out 90s + + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) + + try: + data = await ollama_client.embed(model=embed_model, input=texts) + return np.array(data["embeddings"]) + except Exception as e: + logger.error(f"Error in ollama_embed: {str(e)}") + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after exception in embed") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after exception in embed: {close_error}") + raise e + finally: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after embed") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after embed: {close_error}")