Merge branch 'main' into select-datastore-in-api-server

This commit is contained in:
yangdx
2025-02-13 11:25:52 +08:00
9 changed files with 187 additions and 9 deletions

View File

@@ -158,7 +158,7 @@ if __name__ == "__main__":
 # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
 # 3. Insert file:
-# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
 # 4. Health check:
 # curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -176,7 +176,7 @@ if __name__ == "__main__":
 # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
 # 3. Insert file:
-# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
 # 4. Health check:
 # curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -269,7 +269,8 @@ if __name__ == "__main__":
 # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
 # 3. Insert file:
-# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
 # 4. Health check:
 # curl -X GET "http://127.0.0.1:8020/health"
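The same example change appears in all three demo scripts above: the /insert_file call now sends the file itself as a multipart form upload instead of a JSON path. As a rough client-side sketch (the endpoint and the "file" form field come from the curl example; the use of the requests library here is illustrative):

# Hedged sketch of the updated upload call; only the URL and field name are
# taken from the curl example above.
import requests

with open("path/to/your/file.txt", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8020/insert_file",
        files={"file": f},  # requests builds the multipart/form-data body itself
    )
print(resp.status_code, resp.text)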

View File

@@ -0,0 +1,84 @@
# pip install -q -U google-genai to use gemini as a client

import os
import numpy as np

from google import genai
from google.genai import types
from dotenv import load_dotenv

from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer

load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")

WORKING_DIR = "./dickens"

if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)

os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    # 1. Initialize the GenAI Client with your Gemini API Key
    client = genai.Client(api_key=gemini_api_key)

    # 2. Combine prompts: system prompt, history, and user prompt
    if history_messages is None:
        history_messages = []

    combined_prompt = ""
    if system_prompt:
        combined_prompt += f"{system_prompt}\n"

    for msg in history_messages:
        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
        combined_prompt += f"{msg['role']}: {msg['content']}\n"

    # Finally, add the new user prompt
    combined_prompt += f"user: {prompt}"

    # 3. Call the Gemini model
    response = client.models.generate_content(
        model="gemini-1.5-flash",
        contents=[combined_prompt],
        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
    )

    # 4. Return the response text
    return response.text


async def embedding_func(texts: list[str]) -> np.ndarray:
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, convert_to_numpy=True)
    return embeddings


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=8192,
        func=embedding_func,
    ),
)

file_path = "story.txt"
with open(file_path, "r") as file:
    text = file.read()

rag.insert(text)

response = rag.query(
    query="What is the main theme of the story?",
    param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)

print(response)
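The demo reads its input from story.txt in the working directory. A throwaway helper like the one below (file name taken from the script above, text purely placeholder) can create one for a quick smoke test:

# Create a tiny story.txt so the Gemini demo has something to index.
# The file name comes from the demo; the content is placeholder text.
sample_story = (
    "Once upon a time, a small village survived a hard winter "
    "by sharing its harvest and working together."
)
with open("story.txt", "w") as f:
    f.write(sample_story)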

View File

@@ -228,7 +228,11 @@ class DocStatusStorage(BaseKVStorage):
         raise NotImplementedError
 
     async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all documents that are currently being processed"""
+        """Get all processing documents"""
+        raise NotImplementedError
+
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
         raise NotImplementedError
 
     async def update_doc_status(self, data: dict[str, Any]) -> None:
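The two new hooks follow the same pattern as the existing per-status getters. A minimal in-memory sketch of how a subclass typically satisfies them (illustrative only, not a storage backend shipped with LightRAG; the status strings below are placeholders):

# Toy class showing the call pattern of get_processing_docs / get_processed_docs.
import asyncio


class InMemoryDocStatus:
    def __init__(self) -> None:
        self._docs: dict[str, dict] = {}

    async def get_docs_by_status(self, status: str) -> dict[str, dict]:
        return {k: v for k, v in self._docs.items() if v["status"] == status}

    async def get_processing_docs(self) -> dict[str, dict]:
        return await self.get_docs_by_status("processing")

    async def get_processed_docs(self) -> dict[str, dict]:
        return await self.get_docs_by_status("processed")


store = InMemoryDocStatus()
store._docs["doc-1"] = {"status": "processed", "content_length": 120}
print(asyncio.run(store.get_processed_docs()))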

View File

@@ -14,7 +14,14 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
-from ..base import BaseGraphStorage, BaseKVStorage
+from ..base import (
+    BaseGraphStorage,
+    BaseKVStorage,
+    DocProcessingStatus,
+    DocStatus,
+    DocStatusStorage,
+)
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
@@ -51,7 +58,8 @@ class MongoKVStorage(BaseKVStorage):
     async def filter_keys(self, data: set[str]) -> set[str]:
         existing_ids = [
-            str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
+            str(x["_id"])
+            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
         ]
         return set([s for s in data if s not in existing_ids])
@@ -89,6 +97,82 @@ class MongoKVStorage(BaseKVStorage):
         await self._data.drop()
 
 
+@dataclass
+class MongoDocStatusStorage(DocStatusStorage):
+    def __post_init__(self):
+        client = MongoClient(
+            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+        )
+        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
+        self._data = database.get_collection(self.namespace)
+        logger.info(f"Use MongoDB as doc status {self.namespace}")
+
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        return self._data.find_one({"_id": id})
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        return list(self._data.find({"_id": {"$in": ids}}))
+
+    async def filter_keys(self, data: set[str]) -> set[str]:
+        existing_ids = [
+            str(x["_id"])
+            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        ]
+        return set([s for s in data if s not in existing_ids])
+
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        for k, v in data.items():
+            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            data[k]["_id"] = k
+
+    async def drop(self) -> None:
+        """Drop the collection"""
+        await self._data.drop()
+
+    async def get_status_counts(self) -> dict[str, int]:
+        """Get counts of documents in each status"""
+        pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
+        result = list(self._data.aggregate(pipeline))
+        counts = {}
+        for doc in result:
+            counts[doc["_id"]] = doc["count"]
+        return counts
+
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents by status"""
+        result = list(self._data.find({"status": status.value}))
+        return {
+            doc["_id"]: DocProcessingStatus(
+                content=doc["content"],
+                content_summary=doc.get("content_summary"),
+                content_length=doc["content_length"],
+                status=doc["status"],
+                created_at=doc.get("created_at"),
+                updated_at=doc.get("updated_at"),
+                chunks_count=doc.get("chunks_count", -1),
+            )
+            for doc in result
+        }
+
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return await self.get_docs_by_status(DocStatus.FAILED)
+
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return await self.get_docs_by_status(DocStatus.PENDING)
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSING)
+
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSED)
+
+
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """

View File

@@ -493,10 +493,14 @@ class PGDocStatusStorage(DocStatusStorage):
         """Get all pending documents"""
         return await self.get_docs_by_status(DocStatus.PENDING)
 
-    async def get_processing_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all documents that are currently being processed"""
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
         return await self.get_docs_by_status(DocStatus.PROCESSING)
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSED)
+
     async def index_done_callback(self):
         """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
         logger.info("Doc status had been saved into postgresql db!")

View File

@@ -87,7 +87,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
 
     async def upsert(self, data: dict[str, dict]):
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
             return []

View File

@@ -137,6 +137,7 @@ STORAGE_ENV_REQUIREMENTS = {
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "MongoDocStatusStorage": [],
 }
 
 # Storage implementation module mapping
@@ -151,6 +152,7 @@ STORAGES = {
     "OracleVectorDBStorage": ".kg.oracle_impl",
     "MilvusVectorDBStorge": ".kg.milvus_impl",
     "MongoKVStorage": ".kg.mongo_impl",
+    "MongoDocStatusStorage": ".kg.mongo_impl",
     "MongoGraphStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",