From d517ef9c209b96dc61ff7f3fb860a6f7e2b6d714 Mon Sep 17 00:00:00 2001
From: Soumil
Date: Mon, 21 Oct 2024 18:34:43 +0100
Subject: [PATCH] added a class to use multiple models

---
 lightrag/llm.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..d820766d 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -13,6 +13,8 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any, Awaitable
 
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
@@ -423,6 +425,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
 
     return embed_text
 
+
+class Model(BaseModel):
+    """
+    A Pydantic model that defines a custom language model: a generation
+    function plus the arguments to call it with.
+
+    Attributes:
+        gen_func (Callable[..., Awaitable[str]]): An async callable that generates the
+            response from the language model. It must return a string.
+        kwargs (Dict[str, Any]): A dictionary of arguments to pass to the callable,
+            such as the model name, API key, etc.
+
+    Example usage:
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+    Here, 'openai_complete_if_cache' is the callable that generates the response from
+    the OpenAI model, and 'kwargs' carries the model name and API key to pass to it.
+    """
+
+    gen_func: Callable[..., Awaitable[str]] = Field(..., description="A function that generates the response from the LLM. The response must be a string.")
+    kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable, e.g. the API key, model name, etc.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class MultiModel:
+    """
+    Distributes the load across multiple language models. Useful for working around
+    the low rate limits of some API providers, especially on free tiers.
+    Can also be used to split requests across different models or providers.
+
+    Attributes:
+        models (List[Model]): A list of language models to be used.
+
+    Usage example:
+    ```python
+    models = [
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+    ]
+    multi_model = MultiModel(models)
+    rag = LightRAG(
+        llm_model_func=multi_model.llm_model_func,
+        # ...other args
+    )
+    ```
+    """
+
+    def __init__(self, models: List[Model]):
+        self._models = models
+        self._current_model = 0
+
+    def _next_model(self):
+        # Advance the round-robin pointer, then return the selected model.
+        self._current_model = (self._current_model + 1) % len(self._models)
+        return self._models[self._current_model]
+
+    async def llm_model_func(
+        self, prompt, system_prompt=None, history_messages=[], **kwargs
+    ) -> str:
+        kwargs.pop("model", None)  # stop callers from overwriting the custom model name
+        next_model = self._next_model()
+        # Dict literal rather than dict(): duplicate keys cannot raise TypeError,
+        # and the model's own kwargs take precedence over shared kwargs.
+        args = {"prompt": prompt, "system_prompt": system_prompt, "history_messages": history_messages, **kwargs, **next_model.kwargs}
+
+        return await next_model.gen_func(**args)
 
 if __name__ == "__main__":
     import asyncio
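
Below the patch, a minimal self-contained sketch of the round-robin rotation the new MultiModel class performs. The stub generator and the "stub-a"/"stub-b" model names are hypothetical placeholders, standing in for real API-backed functions such as openai_complete_if_cache, and the import assumes the patch has been applied:

```python
import asyncio

# Assumes this patch is applied so Model and MultiModel exist in lightrag.llm.
from lightrag.llm import Model, MultiModel

# Hypothetical stub standing in for a real API-backed generator such as
# openai_complete_if_cache: it just reports which "model" served the call.
async def stub_gen(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    return f"[{kwargs.get('model')}] {prompt}"

models = [
    Model(gen_func=stub_gen, kwargs={"model": "stub-a"}),
    Model(gen_func=stub_gen, kwargs={"model": "stub-b"}),
]
multi_model = MultiModel(models)

async def main():
    # _next_model() advances the pointer before returning, so the first call
    # is served by the second configured model, then the rotation wraps.
    print(await multi_model.llm_model_func("first call"))   # [stub-b] first call
    print(await multi_model.llm_model_func("second call"))  # [stub-a] second call
    print(await multi_model.llm_model_func("third call"))   # [stub-b] third call

asyncio.run(main())
```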