Add ability to pass additional parameters to ollama library like host and timeout

.gitignore (vendored)

@@ -4,4 +4,5 @@ dickens/
 book.txt
 lightrag-dev/
 .idea/
 dist/
+.venv/

examples/lightrag_ollama_demo.py

@@ -1,4 +1,7 @@
 import os
+import logging
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG)
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
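
With this change the demo can point LightRAG at a non-default Ollama server and raise the request timeout. A minimal sketch of the wiring, assuming the usual demo setup (the model names, host URL, and EmbeddingFunc dimensions below are illustrative, not part of this commit):

    from lightrag import LightRAG
    from lightrag.llm import ollama_model_complete, ollama_embedding
    from lightrag.utils import EmbeddingFunc

    rag = LightRAG(
        working_dir="./dickens",
        llm_model_func=ollama_model_complete,
        llm_model_name="qwen2",  # illustrative model name
        # forwarded to ollama.AsyncClient via the new llm_model_kwargs field
        llm_model_kwargs={"host": "http://localhost:11434", "timeout": 300},
        embedding_func=EmbeddingFunc(
            embedding_dim=768,  # matches nomic-embed-text
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
    )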

lightrag/lightrag.py

@@ -88,6 +88,7 @@ class LightRAG:
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
+    llm_model_kwargs: dict = field(default_factory=dict)
 
     # storage
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage

@@ -154,7 +155,7 @@ class LightRAG:
         )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
-            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
+            partial(self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs)
         )
 
     def insert(self, string_or_strings):
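
The `partial` wrapper is what carries these settings forward: every keyword in `llm_model_kwargs` is pre-bound once at startup and then arrives in the model function's `**kwargs` on every call. A standalone illustration of that mechanism (the names here are made up for the demo):

    from functools import partial

    def call_model(prompt, hashing_kv=None, **kwargs):
        # kwargs carries host/timeout on every invocation
        return hashing_kv, kwargs

    llm_model_kwargs = {"host": "http://localhost:11434", "timeout": 300}
    bound = partial(call_model, hashing_kv="cache", **llm_model_kwargs)

    print(bound("hello"))
    # ('cache', {'host': 'http://localhost:11434', 'timeout': 300})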

lightrag/llm.py

@@ -222,8 +222,10 @@ async def ollama_model_if_cache(
 ) -> str:
     kwargs.pop("max_tokens", None)
     kwargs.pop("response_format", None)
+    host = kwargs.pop("host", None)
+    timeout = kwargs.pop("timeout", None)
 
-    ollama_client = ollama.AsyncClient()
+    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
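
Note that `host` and `timeout` are popped off before the generation call, so only client-level options reach `AsyncClient` while everything else still goes to the chat request. A condensed sketch of the resulting control flow (the `chat` call and response shape follow the ollama Python client's API; the exact surrounding code is not shown in this diff):

    import ollama

    async def ollama_chat_sketch(model, prompt, system_prompt=None, **kwargs) -> str:
        # client-level options are split off first...
        host = kwargs.pop("host", None)
        timeout = kwargs.pop("timeout", None)
        client = ollama.AsyncClient(host=host, timeout=timeout)

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # ...and the remaining kwargs still reach the generation call
        response = await client.chat(model=model, messages=messages, **kwargs)
        return response["message"]["content"]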

@@ -415,10 +417,11 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     return embeddings.detach().numpy()
 
 
-async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
+async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     embed_text = []
+    ollama_client = ollama.Client(**kwargs)
     for text in texts:
-        data = ollama.embeddings(model=embed_model, prompt=text)
+        data = ollama_client.embeddings(model=embed_model, prompt=text)
         embed_text.append(data["embedding"])
 
     return embed_text
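
Since `ollama_embedding` now builds its own `ollama.Client` from `**kwargs`, callers can target a non-default server the same way. A small usage sketch (the server URL and model name are illustrative):

    import asyncio
    from lightrag.llm import ollama_embedding

    embeddings = asyncio.run(
        ollama_embedding(
            ["first chunk", "second chunk"],
            embed_model="nomic-embed-text",
            host="http://localhost:11434",  # forwarded to ollama.Client
        )
    )
    print(len(embeddings), len(embeddings[0]))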