Merge branch 'main' into graph-viewer-webui

ArnoChen
2025-02-13 04:42:57 +08:00
9 changed files with 191 additions and 6 deletions

View File

@@ -158,7 +158,7 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
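
Note: this hunk (and the identical ones in the next two demo files) switches the insert_file example from a JSON body carrying a file_path to a real multipart upload. As a rough illustration of the server side such a curl call expects, here is a minimal sketch assuming a FastAPI app like the demo servers use; the handler body and response shape are illustrative, not the repository's actual endpoint:

from fastapi import FastAPI, File, UploadFile

app = FastAPI()

@app.post("/insert_file")
async def insert_file(file: UploadFile = File(...)):
    # Decode the uploaded bytes and hand the text to the RAG pipeline.
    text = (await file.read()).decode("utf-8")
    # await rag.ainsert(text)  # `rag` is an assumed, already-constructed LightRAG instance
    return {"status": "success", "message": f"{file.filename} inserted"}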

View File

@@ -176,7 +176,7 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -269,7 +269,8 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -0,0 +1,84 @@
# pip install -q -U google-genai to use gemini as a client
import os
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer

load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")

WORKING_DIR = "./dickens"

if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)

os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    # 1. Initialize the GenAI Client with your Gemini API Key
    client = genai.Client(api_key=gemini_api_key)

    # 2. Combine prompts: system prompt, history, and user prompt
    if history_messages is None:
        history_messages = []

    combined_prompt = ""
    if system_prompt:
        combined_prompt += f"{system_prompt}\n"

    for msg in history_messages:
        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
        combined_prompt += f"{msg['role']}: {msg['content']}\n"

    # Finally, add the new user prompt
    combined_prompt += f"user: {prompt}"

    # 3. Call the Gemini model
    response = client.models.generate_content(
        model="gemini-1.5-flash",
        contents=[combined_prompt],
        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
    )

    # 4. Return the response text
    return response.text


async def embedding_func(texts: list[str]) -> np.ndarray:
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, convert_to_numpy=True)
    return embeddings


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=8192,
        func=embedding_func,
    ),
)

file_path = "story.txt"
with open(file_path, "r") as file:
    text = file.read()

rag.insert(text)

response = rag.query(
    query="What is the main theme of the story?",
    param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)

print(response)
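
One note on the demo above: insert() and query() are the blocking wrappers; when already inside an event loop, the async counterparts can be used instead. A minimal sketch, assuming ainsert/aquery mirror the synchronous calls and reusing the rag, text, and QueryParam names from the demo:

import asyncio

async def main() -> None:
    # Assumed async counterparts of the synchronous insert()/query() used above.
    await rag.ainsert(text)
    answer = await rag.aquery(
        "What is the main theme of the story?",
        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
    )
    print(answer)

# asyncio.run(main())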

View File

@@ -130,7 +130,7 @@ if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
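
For context, the storage names assigned above correspond to backends that can also be selected directly on a LightRAG instance. A minimal sketch, assuming LightRAG exposes kv_storage / doc_status_storage / graph_storage fields that take backend class names as strings; connection values are illustrative:

import os
from lightrag import LightRAG

os.environ["MONGO_URI"] = "mongodb://root:root@localhost:27017/"  # illustrative credentials
os.environ["MONGO_DATABASE"] = "LightRAG"

rag = LightRAG(
    working_dir="./workdir",
    kv_storage="MongoKVStorage",
    doc_status_storage="MongoDocStatusStorage",  # the class added in this commit
    graph_storage="MongoGraphStorage",
)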

View File

@@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage):
"""Get all pending documents"""
raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
"""Updates the status of a document. By default, it calls upsert."""
await self.upsert(data)
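
The new hooks above give every DocStatusStorage backend a uniform way to report pipeline progress. An illustrative usage sketch (the argument is any concrete implementation, such as the Mongo one added below):

async def print_pipeline_progress(doc_status: "DocStatusStorage") -> None:
    # Overall counts per status, e.g. {"pending": 3, "processing": 1, "processed": 42}
    print(await doc_status.get_status_counts())

    # Documents currently being processed, keyed by doc id.
    for doc_id, st in (await doc_status.get_processing_docs()).items():
        print(doc_id, st.status, st.chunks_count)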

View File

@@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from ..base import BaseGraphStorage, BaseKVStorage
from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
@@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage):
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
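
The list(data) change above matters because BSON has no set type, so handing a Python set to a $in query fails when pymongo/motor encodes it. A tiny illustration of the two query shapes:

keys = {"doc-1", "doc-2"}

# Fails: pymongo/motor cannot encode a Python set into BSON.
bad_query = {"_id": {"$in": keys}}

# Works: $in expects an array, so convert the set to a list first.
good_query = {"_id": {"$in": list(keys)}}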
@@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage):
        await self._data.drop()


@dataclass
class MongoDocStatusStorage(DocStatusStorage):
    def __post_init__(self):
        client = MongoClient(
            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
        )
        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
        self._data = database.get_collection(self.namespace)
        logger.info(f"Use MongoDB as doc status {self.namespace}")

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        return self._data.find_one({"_id": id})

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        return list(self._data.find({"_id": {"$in": ids}}))

    async def filter_keys(self, data: set[str]) -> set[str]:
        existing_ids = [
            str(x["_id"])
            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
        ]
        return set([s for s in data if s not in existing_ids])

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        for k, v in data.items():
            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
            data[k]["_id"] = k

    async def drop(self) -> None:
        """Drop the collection"""
        await self._data.drop()

    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
        result = list(self._data.aggregate(pipeline))
        counts = {}
        for doc in result:
            counts[doc["_id"]] = doc["count"]
        return counts

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents by status"""
        result = list(self._data.find({"status": status.value}))
        return {
            doc["_id"]: DocProcessingStatus(
                content=doc["content"],
                content_summary=doc.get("content_summary"),
                content_length=doc["content_length"],
                status=doc["status"],
                created_at=doc.get("created_at"),
                updated_at=doc.get("updated_at"),
                chunks_count=doc.get("chunks_count", -1),
            )
            for doc in result
        }

    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        return await self.get_docs_by_status(DocStatus.FAILED)

    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        return await self.get_docs_by_status(DocStatus.PENDING)

    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
        """Get all processing documents"""
        return await self.get_docs_by_status(DocStatus.PROCESSING)

    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
        """Get all processed documents"""
        return await self.get_docs_by_status(DocStatus.PROCESSED)


@dataclass
class MongoGraphStorage(BaseGraphStorage):
    """

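For reference, the fields read back in MongoDocStatusStorage.get_docs_by_status above imply a status record of roughly this shape; only the field names come from the code, the values are made up for illustration:

example_status_doc = {
    "_id": "doc-abc123",
    "content": "full document text ...",
    "content_summary": "short preview of the document ...",
    "content_length": 10240,
    "status": "processed",  # stored DocStatus value
    "created_at": "2025-02-13T04:42:57+08:00",
    "updated_at": "2025-02-13T04:42:57+08:00",
    "chunks_count": 12,
}
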
View File

@@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")

View File

@@ -46,6 +46,7 @@ STORAGES = {
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",