diff --git a/examples/lightrag_api_ollama_demo.py b/examples/lightrag_api_ollama_demo.py index 634264d3..079e9935 100644 --- a/examples/lightrag_api_ollama_demo.py +++ b/examples/lightrag_api_ollama_demo.py @@ -158,7 +158,7 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' # 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' +# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" # 4. Health check: # curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index e2d63e41..68ccfe95 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -176,7 +176,7 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' # 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' +# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" # 4. Health check: # curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py index 602ca900..6162a300 100644 --- a/examples/lightrag_api_oracle_demo.py +++ b/examples/lightrag_api_oracle_demo.py @@ -269,7 +269,8 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' # 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' +# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" + # 4. Health check: # curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_gemini_demo.py b/examples/lightrag_gemini_demo.py new file mode 100644 index 00000000..32732ba8 --- /dev/null +++ b/examples/lightrag_gemini_demo.py @@ -0,0 +1,84 @@ +# pip install -q -U google-genai to use gemini as a client + +import os +import numpy as np +from google import genai +from google.genai import types +from dotenv import load_dotenv +from lightrag.utils import EmbeddingFunc +from lightrag import LightRAG, QueryParam +from sentence_transformers import SentenceTransformer + +load_dotenv() +gemini_api_key = os.getenv("GEMINI_API_KEY") + +WORKING_DIR = "./dickens" + +if os.path.exists(WORKING_DIR): + import shutil + + shutil.rmtree(WORKING_DIR) + +os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + # 1. Initialize the GenAI Client with your Gemini API Key + client = genai.Client(api_key=gemini_api_key) + + # 2. Combine prompts: system prompt, history, and user prompt + if history_messages is None: + history_messages = [] + + combined_prompt = "" + if system_prompt: + combined_prompt += f"{system_prompt}\n" + + for msg in history_messages: + # Each msg is expected to be a dict: {"role": "...", "content": "..."} + combined_prompt += f"{msg['role']}: {msg['content']}\n" + + # Finally, add the new user prompt + combined_prompt += f"user: {prompt}" + + # 3. Call the Gemini model + response = client.models.generate_content( + model="gemini-1.5-flash", + contents=[combined_prompt], + config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1), + ) + + # 4. Return the response text + return response.text + + +async def embedding_func(texts: list[str]) -> np.ndarray: + model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = model.encode(texts, convert_to_numpy=True) + return embeddings + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=8192, + func=embedding_func, + ), +) + +file_path = "story.txt" +with open(file_path, "r") as file: + text = file.read() + +rag.insert(text) + +response = rag.query( + query="What is the main theme of the story?", + param=QueryParam(mode="hybrid", top_k=5, response_type="single line"), +) + +print(response) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f8f07ea1..42a597c2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -130,7 +130,7 @@ if mongo_uri: os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_DATABASE"] = mongo_database rag_storage_config.KV_STORAGE = "MongoKVStorage" - rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage" if mongo_graph: rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage" diff --git a/lightrag/base.py b/lightrag/base.py index bd79d990..3702b49e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage): """Get all pending documents""" raise NotImplementedError + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + raise NotImplementedError + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + raise NotImplementedError + async def update_doc_status(self, data: dict[str, Any]) -> None: """Updates the status of a document. By default, it calls upsert.""" await self.upsert(data) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 4f919ecd..8662d005 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient -from ..base import BaseGraphStorage, BaseKVStorage +from ..base import ( + BaseGraphStorage, + BaseKVStorage, + DocProcessingStatus, + DocStatus, + DocStatusStorage, +) from ..namespace import NameSpace, is_namespace from ..utils import logger @@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage): async def filter_keys(self, data: set[str]) -> set[str]: existing_ids = [ - str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1}) + str(x["_id"]) + for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) ] return set([s for s in data if s not in existing_ids]) @@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage): await self._data.drop() +@dataclass +class MongoDocStatusStorage(DocStatusStorage): + def __post_init__(self): + client = MongoClient( + os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + ) + database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) + self._data = database.get_collection(self.namespace) + logger.info(f"Use MongoDB as doc status {self.namespace}") + + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + return self._data.find_one({"_id": id}) + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + return list(self._data.find({"_id": {"$in": ids}})) + + async def filter_keys(self, data: set[str]) -> set[str]: + existing_ids = [ + str(x["_id"]) + for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + ] + return set([s for s in data if s not in existing_ids]) + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + for k, v in data.items(): + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + data[k]["_id"] = k + + async def drop(self) -> None: + """Drop the collection""" + await self._data.drop() + + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] + result = list(self._data.aggregate(pipeline)) + counts = {} + for doc in result: + counts[doc["_id"]] = doc["count"] + return counts + + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents by status""" + result = list(self._data.find({"status": status.value})) + return { + doc["_id"]: DocProcessingStatus( + content=doc["content"], + content_summary=doc.get("content_summary"), + content_length=doc["content_length"], + status=doc["status"], + created_at=doc.get("created_at"), + updated_at=doc.get("updated_at"), + chunks_count=doc.get("chunks_count", -1), + ) + for doc in result + } + + async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all failed documents""" + return await self.get_docs_by_status(DocStatus.FAILED) + + async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: + """Get all pending documents""" + return await self.get_docs_by_status(DocStatus.PENDING) + + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + return await self.get_docs_by_status(DocStatus.PROCESSING) + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + return await self.get_docs_by_status(DocStatus.PROCESSED) + + @dataclass class MongoGraphStorage(BaseGraphStorage): """ diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 79496abd..00ab9f0b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage): """Get all pending documents""" return await self.get_docs_by_status(DocStatus.PENDING) + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + return await self.get_docs_by_status(DocStatus.PROCESSING) + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + return await self.get_docs_by_status(DocStatus.PROCESSED) + async def index_done_callback(self): """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" logger.info("Doc status had been saved into postgresql db!") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b9f5f293..726e7512 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -46,6 +46,7 @@ STORAGES = { "OracleVectorDBStorage": ".kg.oracle_impl", "MilvusVectorDBStorge": ".kg.milvus_impl", "MongoKVStorage": ".kg.mongo_impl", + "MongoDocStatusStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",