Merge pull request #25 from HKUDS/test

Add Ollama models support
This commit is contained in:
zrguo
2024-10-16 17:48:17 +08:00
committed by GitHub
7 changed files with 167 additions and 11 deletions

README.md

@@ -7,7 +7,6 @@
<p>
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<img src="https://badges.pufler.dev/visits/hkuds/lightrag?style=flat-square&logo=github">
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
</p>
<p>
@@ -21,7 +20,8 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div>
## 🎉 News
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-ollama-models)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-hugging-face-models)!
## Install
@@ -37,7 +37,7 @@ pip install lightrag-hku
```
## Quick Start
* All the code can be found in the `examples` directory.
* Set OpenAI API key in environment if using OpenAI models: `export OPENAI_API_KEY="sk-..."`.
* Download the demo text "A Christmas Carol by Charles Dickens":
```bash
@@ -75,6 +75,42 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```
### OpenAI-like APIs
LightRAG also supports OpenAI-like chat/embedding APIs:
```python
import os

import numpy as np

# openai_complete_if_cache and openai_embedding are provided by lightrag.llm;
# EmbeddingFunc comes from lightrag.utils (same imports as in the bundled demos)
from lightrag import LightRAG
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar"
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096,
max_token_size=8192,
func=embedding_func
)
)
```
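With this configuration the instance is used exactly as in the Quick Start; a minimal sketch (it assumes `book.txt` has been downloaded as shown above and that `QueryParam` is imported from `lightrag`):

```python
# Index the demo document and run a hybrid query against the OpenAI-compatible backend
with open("./book.txt") as f:
    rag.insert(f.read())

print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```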
### Using Hugging Face Models
If you want to use Hugging Face models, you only need to set LightRAG as follows:
```python
@@ -84,7 +120,7 @@ from transformers import AutoModel, AutoTokenizer
# Initialize LightRAG with Hugging Face model
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=hf_model_complete,  # Use Hugging Face model for text generation
    llm_model_name='meta-llama/Llama-3.1-8B-Instruct',  # Model name from Hugging Face
    # Use Hugging Face embedding function
    embedding_func=EmbeddingFunc(
@@ -98,11 +134,35 @@ rag = LightRAG(
    ),
)
```
### Using Ollama Models
If you want to use Ollama models, you only need to set LightRAG as follows:
```python
# WORKING_DIR is defined as in the Quick Start above
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
# Initialize LightRAG with Ollama model
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, # Use Ollama model for text generation
llm_model_name='your_model_name', # Your model name
# Use Ollama embedding function
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
),
)
```
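The chat model and the `nomic-embed-text` embedding model must already be available in the local Ollama instance. A minimal sketch for fetching them with the `ollama` Python client (this assumes the Ollama server is running locally and that `pull` is the client's model-download call; `your_model_name` is the same placeholder as above):

```python
import ollama

# Download the chat and embedding models into the local Ollama instance if they are missing
for model in ["your_model_name", "nomic-embed-text"]:
    ollama.pull(model)
```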
### Batch Insert
```python
# Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
```
### Incremental Insert
```python
@@ -186,6 +246,7 @@ Output your evaluation in the following JSON format:
}}
}}
```
### Overall Performance Table
| | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
|----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
@@ -212,6 +273,7 @@ Output your evaluation in the following JSON format:
## Reproduce
All the code can be found in the `./reproduce` directory.
### Step-0 Extract Unique Contexts
First, we need to extract unique contexts in the datasets.
```python
@@ -265,6 +327,7 @@ def extract_unique_contexts(input_directory, output_directory):
print("All files have been processed.") print("All files have been processed.")
``` ```
### Step-1 Insert Contexts ### Step-1 Insert Contexts
For the extracted contexts, we insert them into the LightRAG system. For the extracted contexts, we insert them into the LightRAG system.
@@ -286,6 +349,7 @@ def insert_text(rag, file_path):
    if retries == max_retries:
        print("Insertion failed after exceeding the maximum number of retries")
```
### Step-2 Generate Queries
We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
@@ -326,8 +390,10 @@ def extract_queries(file_path):
├── examples
│   ├── batch_eval.py
│   ├── generate_query.py
│   ├── lightrag_hf_demo.py
│   ├── lightrag_ollama_demo.py
│   ├── lightrag_openai_compatible_demo.py
│   └── lightrag_openai_demo.py
├── lightrag
│   ├── __init__.py
│   ├── base.py

examples/lightrag_ollama_demo.py

@@ -0,0 +1,40 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
llm_model_name='your_model_name',
embedding_func=EmbeddingFunc(
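        # nomic-embed-text returns 768-dimensional vectors, so embedding_dim is set to 768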
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
),
)
with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))

lightrag/__init__.py

@@ -1,5 +1,5 @@
from .lightrag import LightRAG, QueryParam

__version__ = "0.0.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

lightrag/lightrag.py

@@ -6,7 +6,7 @@ from functools import partial
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
from .operate import (
    chunking_by_token_size,
    extract_entities,

lightrag/llm.py

@@ -1,5 +1,6 @@
import os
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
    retry,
@@ -92,6 +93,34 @@ async def hf_model_if_cache(
    )
    return response_text


async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
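    # Ollama's chat API does not accept OpenAI-style max_tokens / response_format
    # arguments, so drop them from the kwargs before the call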
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
ollama_client = ollama.AsyncClient()
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
result = response["message"]["content"]
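    # Cache the new answer under a hash of the model name and message list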
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
return result


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
@@ -116,8 +145,6 @@ async def gpt_4o_mini_complete(
        **kwargs,
    )


async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -130,6 +157,18 @@ async def hf_model_complete(
        **kwargs,
    )


async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
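    # The Ollama model name is taken from LightRAG's global config (llm_model_name),
    # which is reachable through the hashing_kv storage passed in kwargs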
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
return await ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
@@ -154,6 +193,13 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.detach().numpy()


async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
embed_text = []
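    # Embed each text with Ollama one at a time and collect the vectors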
for text in texts:
data = ollama.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text


if __name__ == "__main__":
    import asyncio

requirements.txt

@@ -6,3 +6,7 @@ nano-vectordb
hnswlib
xxhash
tenacity
transformers
torch
ollama
accelerate

setup.py

@@ -1,6 +1,6 @@
import setuptools
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()