ollama test

LarFii
2024-10-16 15:15:10 +08:00
parent 8946023b54
commit 92c11179fe
6 changed files with 94 additions and 5 deletions

examples/lightrag_ollama_demo.py

@@ -0,0 +1,40 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name='your_model_name',
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(
            texts,
            embed_model="nomic-embed-text"
        )
    ),
)

with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))

# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))

# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))

# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))

lightrag/__init__.py

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG, QueryParam
-__version__ = "0.0.5"
+__version__ = "0.0.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

lightrag/llm.py

@@ -1,5 +1,6 @@
 import os
 import numpy as np
+import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
 from tenacity import (
     retry,
@@ -92,6 +93,34 @@ async def hf_model_if_cache(
     )
     return response_text
+async def ollama_model_if_cache(
+    model, prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
+    ollama_client = ollama.AsyncClient()
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    # Return a cached response if this exact (model, messages) pair was seen before.
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+    result = response["message"]["content"]
+    # Store the fresh result under the args hash for future calls.
+    if hashing_kv is not None:
+        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
+    return result
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
@@ -116,8 +145,6 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -130,6 +157,18 @@ async def hf_model_complete(
         **kwargs,
     )
+async def ollama_model_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
+    return await ollama_model_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
@@ -154,6 +193,13 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()
+async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
+    embed_text = []
+    for text in texts:
+        data = ollama.embeddings(model=embed_model, prompt=text)
+        embed_text.append(data["embedding"])
+    # Convert to an array so the return value matches the declared np.ndarray.
+    return np.array(embed_text)
 if __name__ == "__main__":
     import asyncio
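To exercise the new helpers outside of LightRAG, something like the following works. This is a minimal sketch: "llama3" is an assumed model name (use any chat model you have pulled locally), and no hashing_kv is passed, so the cache branch is skipped entirely:

import asyncio
import numpy as np
from lightrag.llm import ollama_model_if_cache, ollama_embedding

async def main():
    answer = await ollama_model_if_cache(
        "llama3",  # assumption: any locally pulled chat model works here
        "Why is the sky blue?",
        system_prompt="Answer in one sentence.",
    )
    print(answer)

    vectors = await ollama_embedding(["hello world"], embed_model="nomic-embed-text")
    print(np.array(vectors).shape)  # (1, 768): nomic-embed-text produces 768-dim vectors

asyncio.run(main())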

requirements.txt

@@ -6,3 +6,6 @@ nano-vectordb
 hnswlib
 xxhash
 tenacity
+transformers
+torch
+ollama

setup.py

@@ -1,6 +1,6 @@
 import setuptools
-with open("README.md", "r") as fh:
+with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()
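Passing encoding="utf-8" explicitly makes the README read deterministic: without it, open() falls back to the platform's default encoding (e.g. cp1252 on Windows), which can raise UnicodeDecodeError on non-ASCII characters during installation.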