Merge pull request #87 from Soumil32/main
Added a class to use multiple models
@@ -19,6 +19,8 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs

@@ -554,6 +556,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
    return embed_text

class Model(BaseModel):
    """
    A Pydantic model class used to define a custom language model.

    Attributes:
        gen_func (Callable[[Any], str]): A callable that generates the response
            from the language model. It is awaited by MultiModel, so it should
            be an async function; it may take any arguments and must return a
            string.
        kwargs (Dict[str, Any]): The arguments to pass to the callable, e.g.
            the model name, API key, etc.

    Example usage:
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})

    In this example, 'openai_complete_if_cache' is the callable that generates
    the response from the OpenAI model, and the 'kwargs' dictionary carries the
    model name and API key it is called with.
    """

    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the LLM. The response must be a string.",
    )
    kwargs: Dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function, e.g. the API key, model name, etc.",
    )

    class Config:
        arbitrary_types_allowed = True
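
# A minimal sketch (not part of the original commit): any async callable that
# returns a string satisfies the gen_func contract. `_echo_complete` below is
# a hypothetical stand-in, shown only to illustrate how a Model is built.
async def _echo_complete(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    # Echo the prompt back, tagged with whichever model name was configured.
    return f"[{kwargs.get('model', 'unknown')}] {prompt}"

# e.g. Model(gen_func=_echo_complete, kwargs={"model": "echo-v1"})
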
class MultiModel:
    """
    Distributes load across multiple language models. Useful for circumventing
    low rate limits on certain API providers, especially on a free tier. It can
    also be used to split requests across different models or providers.

    Attributes:
        models (List[Model]): A list of language models to be used.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func,
        # ...other args
    )
    ```
    """
    def __init__(self, models: List[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self):
        # Round-robin: advance the index, wrapping around the model list.
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self, prompt, system_prompt=None, history_messages=None, **kwargs
    ) -> str:
        if history_messages is None:  # avoid a mutable default argument
            history_messages = []
        kwargs.pop("model", None)  # stop the caller from overwriting the custom model name
        next_model = self._next_model()
        # Merge with a dict literal so duplicate keys do not raise a TypeError;
        # the per-model kwargs take precedence over call-time kwargs.
        args = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "history_messages": history_messages,
            **kwargs,
            **next_model.kwargs,
        }

        return await next_model.gen_func(**args)
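
# A hedged, illustrative demo (not part of the original commit) of the
# round-robin behaviour. Note that _next_model advances the index *before*
# picking, so the first call is served by models[1], then models[2], and
# only then does it wrap around to models[0].
#
#     models = [Model(gen_func=_echo_complete, kwargs={"model": f"m{i}"}) for i in range(3)]
#     multi = MultiModel(models)
#     for _ in range(3):
#         print(asyncio.run(multi.llm_model_func("hello")))  # m1, m2, m0
#
# `_echo_complete` is the hypothetical generator sketched above; asyncio.run
# is used only because llm_model_func is a coroutine.
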
if __name__ == "__main__":
    import asyncio