Add ability to pass additional parameters to the ollama library, such as host and timeout

Andrii Lazarchuk
2024-10-21 11:53:06 +00:00
parent f49de420cf
commit 108fc4a1ee
4 changed files with 151 additions and 13 deletions

.gitignore (new vendored file, 121 lines added)

@@ -0,0 +1,121 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
venv/
ENV/
env.bak/
venv.bak/
*.egg
*.egg-info/
dist/
build/
*.whl
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.env.*
.venv
.venv.*
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyderworkspace
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# Example files
book.txt
dickens/


@@ -1,4 +1,7 @@
 import os
+import logging
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG)

 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -11,15 +14,17 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=ollama_model_complete,
-    llm_model_name='your_model_name',
+    tiktoken_model_name="mistral:7b",
+    llm_model_func=ollama_model_complete,
+    llm_model_name="mistral:7b",
+    llm_model_max_async=2,
+    llm_model_kwargs={"host": "http://localhost:11434"},
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
         func=lambda texts: ollama_embedding(
-            texts,
-            embed_model="nomic-embed-text"
-        )
+            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
+        ),
     ),
 )
@@ -28,13 +33,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())

 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)

 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)

 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)

 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
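Taken together, the example exercises the new plumbing: llm_model_kwargs is forwarded to ollama.AsyncClient for completions, and extra keyword arguments to ollama_embedding are forwarded to ollama.Client. Below is a minimal sketch of a configuration that also sets a request timeout; the remote host URL, the 300-second timeout, the lightrag.utils location of EmbeddingFunc, and the working-directory value are illustrative assumptions, not part of this commit.

import os

from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc  # assumed location of EmbeddingFunc

WORKING_DIR = "./dickens"  # assumed; matches the dickens/ entry in .gitignore
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

OLLAMA_HOST = "http://192.168.1.50:11434"  # hypothetical remote Ollama server

rag = LightRAG(
    working_dir=WORKING_DIR,
    tiktoken_model_name="mistral:7b",
    llm_model_func=ollama_model_complete,
    llm_model_name="mistral:7b",
    llm_model_max_async=2,
    # Forwarded through the LLM call chain into ollama.AsyncClient(host=..., timeout=...)
    llm_model_kwargs={"host": OLLAMA_HOST, "timeout": 300},
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        # ollama_embedding now accepts **kwargs and builds ollama.Client(**kwargs)
        func=lambda texts: ollama_embedding(
            texts, embed_model="nomic-embed-text", host=OLLAMA_HOST
        ),
    ),
)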


@@ -86,6 +86,7 @@ class LightRAG:
     llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
+    llm_model_kwargs: dict = field(default_factory=dict)

     # storage
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -158,7 +159,7 @@ class LightRAG:
         )
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
-            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
+            partial(self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs)
         )

     def insert(self, string_or_strings):
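The second hunk is where llm_model_kwargs actually takes effect: functools.partial pre-binds the response cache and the user-supplied keyword arguments, so every later completion call receives them without any call site changing. A small self-contained illustration of that binding pattern follows; the function and values are stand-ins, not LightRAG code.

import asyncio
from functools import partial

# Stand-in for a completion function that, like ollama_model_complete, accepts
# connection-related keyword arguments and a hashing_kv cache handle.
async def demo_model_func(prompt, hashing_kv=None, host=None, timeout=None, **kwargs):
    return f"prompt={prompt!r}, host={host}, timeout={timeout}"

llm_model_kwargs = {"host": "http://localhost:11434", "timeout": 60}

# Same shape as the change above: bind the cache and the extra kwargs once...
bound_model_func = partial(demo_model_func, hashing_kv=None, **llm_model_kwargs)

# ...and later calls only pass the per-request arguments.
print(asyncio.run(bound_model_func("Hello")))
# prompt='Hello', host=http://localhost:11434, timeout=60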


@@ -98,8 +98,10 @@ async def ollama_model_if_cache(
 ) -> str:
     kwargs.pop("max_tokens", None)
     kwargs.pop("response_format", None)
+    host = kwargs.pop("host", None)
+    timeout = kwargs.pop("timeout", None)
-    ollama_client = ollama.AsyncClient()
+    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -193,10 +195,11 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()


-async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
+async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     embed_text = []
+    ollama_client = ollama.Client(**kwargs)
     for text in texts:
-        data = ollama.embeddings(model=embed_model, prompt=text)
+        data = ollama_client.embeddings(model=embed_model, prompt=text)
         embed_text.append(data["embedding"])

     return embed_text
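With the client built from **kwargs, existing callers that pass nothing extra keep the previous behaviour (ollama.Client() defaults to the local server), while callers can now point the embedder at a specific host. A usage sketch; it assumes a reachable Ollama instance with the nomic-embed-text model pulled, so the host value is illustrative.

import asyncio

from lightrag.llm import ollama_embedding

async def main():
    # Any ollama.Client keyword arguments (e.g. host) are forwarded as-is;
    # omitting them preserves the old default of a local Ollama server.
    embeddings = await ollama_embedding(
        ["LightRAG builds a knowledge graph from raw text."],
        embed_model="nomic-embed-text",
        host="http://localhost:11434",
    )
    print(len(embeddings), len(embeddings[0]))  # 1 embedding, 768-dimensional

asyncio.run(main())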