From e7b6c55beeb1854aa0b32d634d8a2719f23a65da Mon Sep 17 00:00:00 2001
From: Gurjot Singh
Date: Tue, 11 Feb 2025 23:43:42 +0530
Subject: [PATCH] Integrate Gemini client into LightRAG

---
 examples/lightrag_gemini_demo.py | 90 ++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 examples/lightrag_gemini_demo.py

diff --git a/examples/lightrag_gemini_demo.py b/examples/lightrag_gemini_demo.py
new file mode 100644
index 00000000..ff2eadac
--- /dev/null
+++ b/examples/lightrag_gemini_demo.py
@@ -0,0 +1,90 @@
+# pip install -q -U google-genai to use Gemini as the LLM client
+
+import os
+import shutil
+
+import numpy as np
+from google import genai
+from google.genai import types
+from dotenv import load_dotenv
+from sentence_transformers import SentenceTransformer
+
+from lightrag import LightRAG, QueryParam
+from lightrag.utils import EmbeddingFunc
+
+load_dotenv()
+gemini_api_key = os.getenv("GEMINI_API_KEY")
+
+WORKING_DIR = "./dickens"
+
+# Start the demo from a clean working directory
+if os.path.exists(WORKING_DIR):
+    shutil.rmtree(WORKING_DIR)
+os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
+) -> str:
+    # 1. Initialize the GenAI client with your Gemini API key
+    client = genai.Client(api_key=gemini_api_key)
+
+    # 2. Combine prompts: system prompt, history, and the new user prompt
+    if history_messages is None:
+        history_messages = []
+
+    combined_prompt = ""
+    if system_prompt:
+        combined_prompt += f"{system_prompt}\n"
+
+    for msg in history_messages:
+        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
+        combined_prompt += f"{msg['role']}: {msg['content']}\n"
+
+    # Finally, add the new user prompt
+    combined_prompt += f"user: {prompt}"
+
+    # 3. Call the Gemini model
+    response = client.models.generate_content(
+        model="gemini-1.5-flash",
+        contents=[combined_prompt],
+        config=types.GenerateContentConfig(
+            max_output_tokens=500,
+            temperature=0.1,
+        ),
+    )
+
+    # 4. Return the response text
+    return response.text
+
+
+# Load the embedding model once so it is not re-initialized on every call
+embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return embedding_model.encode(texts, convert_to_numpy=True)
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=384,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+)
+
+file_path = "story.txt"
+with open(file_path, "r", encoding="utf-8") as file:
+    text = file.read()
+
+rag.insert(text)
+
+response = rag.query(
+    query="What is the main theme of the story?",
+    param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
+)
+
+print(response)
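
The example keeps the LLM on Gemini but computes embeddings locally with sentence-transformers. If the same google-genai client should serve the embeddings as well, a minimal sketch is below; it assumes the SDK's embed_content call and the "text-embedding-004" model (768-dimensional vectors), so EmbeddingFunc would need embedding_dim=768. Both the model name and the response shape are assumptions to verify against the installed SDK version.

    async def gemini_embedding_func(texts: list[str]) -> np.ndarray:
        # Assumption: client.models.embed_content accepts a list of strings and
        # returns .embeddings, one object with a .values vector per input text.
        client = genai.Client(api_key=gemini_api_key)
        result = client.models.embed_content(model="text-embedding-004", contents=texts)
        return np.array([e.values for e in result.embeddings])

Wired into LightRAG the same way as the local model, only the dimension changes:

    embedding_func=EmbeddingFunc(embedding_dim=768, max_token_size=8192, func=gemini_embedding_func)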