From 25a2dd41c1e39801f029fcd9fb128b4d8b45356d Mon Sep 17 00:00:00 2001
From: Andrii Lazarchuk
Date: Mon, 21 Oct 2024 11:53:06 +0000
Subject: [PATCH] Add ability to pass additional parameters to ollama library
 like host and timeout

---
 .gitignore                       | 3 ++-
 examples/lightrag_ollama_demo.py | 3 +++
 lightrag/lightrag.py             | 3 ++-
 lightrag/llm.py                  | 9 ++++++---
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5a41ae32..9ce353de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ dickens/
 book.txt
 lightrag-dev/
 .idea/
-dist/
\ No newline at end of file
+dist/
+.venv/
\ No newline at end of file
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index c61b71c0..f968d26e 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -1,4 +1,7 @@
 import os
+import logging
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG)
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5137af42..d4b1eaa1 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -88,6 +88,7 @@ class LightRAG:
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
+    llm_model_kwargs: dict = field(default_factory=dict)
 
     # storage
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -154,7 +155,7 @@
         )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
-            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
+            partial(self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs)
         )
 
     def insert(self, string_or_strings):
diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..aa818995 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -222,8 +222,10 @@ async def ollama_model_if_cache(
 ) -> str:
     kwargs.pop("max_tokens", None)
     kwargs.pop("response_format", None)
+    host = kwargs.pop("host", None)
+    timeout = kwargs.pop("timeout", None)
 
-    ollama_client = ollama.AsyncClient()
+    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -415,10 +417,11 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     return embeddings.detach().numpy()
 
 
-async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
+async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     embed_text = []
+    ollama_client = ollama.Client(**kwargs)
     for text in texts:
-        data = ollama.embeddings(model=embed_model, prompt=text)
+        data = ollama_client.embeddings(model=embed_model, prompt=text)
         embed_text.append(data["embedding"])
 
     return embed_text
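
---

Usage note (not part of the commit): a minimal sketch of how the new
llm_model_kwargs field, and the **kwargs now accepted by ollama_embedding,
might be wired up, modeled on examples/lightrag_ollama_demo.py. The host,
timeout value, model names, and working directory below are illustrative
assumptions, not values taken from this patch.

    from lightrag import LightRAG, QueryParam
    from lightrag.llm import ollama_model_complete, ollama_embedding
    from lightrag.utils import EmbeddingFunc

    rag = LightRAG(
        working_dir="./dickens",  # assumed working directory
        llm_model_func=ollama_model_complete,
        llm_model_name="mistral:7b",  # any model already pulled into Ollama
        # Forwarded into ollama.AsyncClient via the partial() built in
        # LightRAG.__post_init__; ollama_model_if_cache pops host/timeout
        # off kwargs before constructing the client.
        llm_model_kwargs={"host": "http://localhost:11434", "timeout": 300},
        embedding_func=EmbeddingFunc(
            embedding_dim=768,
            max_token_size=8192,
            # ollama_embedding now forwards **kwargs to ollama.Client, so
            # the embedding host can differ from the completion host.
            func=lambda texts: ollama_embedding(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
    )

    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))

Design note: host and timeout are popped out of kwargs because they
configure the ollama.AsyncClient itself rather than an individual chat
request; popping them first keeps client settings from leaking into the
model call.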