Compare commits

..

3 Commits

Author SHA1 Message Date
87d87f4ed6 add missing branch
All checks were successful
Build and Push Docker Image / build-and-push (push) Successful in 11m3s
2025-05-22 04:48:24 +08:00
645b294cce add build script 2025-05-22 04:46:57 +08:00
8916f8a912 feat: add delete method for mongo storage implement
All checks were successful
Linting and Formatting / lint-and-format (push) Successful in 3m47s
2025-05-22 04:41:52 +08:00
5 changed files with 17 additions and 210 deletions

View File

@@ -53,6 +53,7 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
**kwargs,
) )
return response return response
except Exception as e: except Exception as e:

View File

@@ -1,155 +0,0 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
llama_index_complete_if_cache,
llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
import asyncio
import nest_asyncio
nest_asyncio.apply()
from lightrag.kg.shared_storage import initialize_pipeline_status
# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")
# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gemma-3-4b")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "arctic-embed")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# LiteLLM configuration
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
print(f"LITELLM_URL: {LITELLM_URL}")
LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-4JdvGFKqSA3S0k_5p0xufw")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize LiteLLM if not in kwargs
if "llm_instance" not in kwargs:
llm_instance = LiteLLM(
model=f"openai/{LLM_MODEL}", # Format: "provider/model_name"
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
temperature=0.7,
)
kwargs["llm_instance"] = llm_instance
chat_kwargs = {}
chat_kwargs["litellm_params"] = {
"metadata": {
"opik": {
"project_name": "lightrag_llamaindex_litellm_opik_demo",
"tags": ["lightrag", "litellm"],
}
}
}
response = await llama_index_complete_if_cache(
kwargs["llm_instance"],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
chat_kwargs=chat_kwargs,
)
return response
except Exception as e:
print(f"LLM request failed: {str(e)}")
raise
# Initialize embedding function
async def embedding_func(texts):
try:
embed_model = LiteLLMEmbedding(
model_name=f"openai/{EMBEDDING_MODEL}",
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
)
return await llama_index_embed(texts, embed_model=embed_model)
except Exception as e:
print(f"Embedding failed: {str(e)}")
raise
# Get embedding dimension
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"embedding_dim={embedding_dim}")
return embedding_dim
async def initialize_rag():
embedding_dimension = await get_embedding_dim()
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Test different query modes
print("\nNaive Search:")
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
print("\nLocal Search:")
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
print("\nGlobal Search:")
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
print("\nHybrid Search:")
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
if __name__ == "__main__":
main()

View File

@@ -84,30 +84,22 @@ class InsertTextRequest(BaseModel):
Attributes: Attributes:
text: The text content to be inserted into the RAG system text: The text content to be inserted into the RAG system
file_source: Source of the text (optional)
""" """
text: str = Field( text: str = Field(
min_length=1, min_length=1,
description="The text to insert", description="The text to insert",
) )
file_source: str = Field(default=None, min_length=0, description="File Source")
@field_validator("text", mode="after") @field_validator("text", mode="after")
@classmethod @classmethod
def strip_text_after(cls, text: str) -> str: def strip_after(cls, text: str) -> str:
return text.strip() return text.strip()
@field_validator("file_source", mode="after")
@classmethod
def strip_source_after(cls, file_source: str) -> str:
return file_source.strip()
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"example": { "example": {
"text": "This is a sample text to be inserted into the RAG system.", "text": "This is a sample text to be inserted into the RAG system."
"file_source": "Source of the text (optional)",
} }
} }
@@ -117,37 +109,25 @@ class InsertTextsRequest(BaseModel):
Attributes: Attributes:
texts: List of text contents to be inserted into the RAG system texts: List of text contents to be inserted into the RAG system
file_sources: Sources of the texts (optional)
""" """
texts: list[str] = Field( texts: list[str] = Field(
min_length=1, min_length=1,
description="The texts to insert", description="The texts to insert",
) )
file_sources: list[str] = Field(
default=None, min_length=0, description="Sources of the texts"
)
@field_validator("texts", mode="after") @field_validator("texts", mode="after")
@classmethod @classmethod
def strip_texts_after(cls, texts: list[str]) -> list[str]: def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts] return [text.strip() for text in texts]
@field_validator("file_sources", mode="after")
@classmethod
def strip_sources_after(cls, file_sources: list[str]) -> list[str]:
return [file_source.strip() for file_source in file_sources]
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"example": { "example": {
"texts": [ "texts": [
"This is the first text to be inserted.", "This is the first text to be inserted.",
"This is the second text to be inserted.", "This is the second text to be inserted.",
], ]
"file_sources": [
"First file source (optional)",
],
} }
} }
@@ -676,25 +656,16 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def pipeline_index_texts( async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
rag: LightRAG, texts: List[str], file_sources: List[str] = None
):
"""Index a list of texts """Index a list of texts
Args: Args:
rag: LightRAG instance rag: LightRAG instance
texts: The texts to index texts: The texts to index
file_sources: Sources of the texts
""" """
if not texts: if not texts:
return return
if file_sources is not None: await rag.apipeline_enqueue_documents(texts)
if len(file_sources) != 0 and len(file_sources) != len(texts):
[
file_sources.append("unknown_source")
for _ in range(len(file_sources), len(texts))
]
await rag.apipeline_enqueue_documents(input=texts, file_paths=file_sources)
await rag.apipeline_process_enqueue_documents() await rag.apipeline_process_enqueue_documents()
@@ -845,12 +816,7 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500). HTTPException: If an error occurs during text processing (500).
""" """
try: try:
background_tasks.add_task( background_tasks.add_task(pipeline_index_texts, rag, [request.text])
pipeline_index_texts,
rag,
[request.text],
file_sources=[request.file_source],
)
return InsertResponse( return InsertResponse(
status="success", status="success",
message="Text successfully received. Processing will continue in background.", message="Text successfully received. Processing will continue in background.",
@@ -885,12 +851,7 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500). HTTPException: If an error occurs during text processing (500).
""" """
try: try:
background_tasks.add_task( background_tasks.add_task(pipeline_index_texts, rag, request.texts)
pipeline_index_texts,
rag,
request.texts,
file_sources=request.file_sources,
)
return InsertResponse( return InsertResponse(
status="success", status="success",
message="Text successfully received. Processing will continue in background.", message="Text successfully received. Processing will continue in background.",

View File

@@ -78,10 +78,6 @@ class QueryRequest(BaseModel):
description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
) )
ids: list[str] | None = Field(
default=None, description="List of ids to filter the results."
)
user_prompt: Optional[str] = Field( user_prompt: Optional[str] = Field(
default=None, default=None,
description="User-provided prompt for the query. If provided, this will be used instead of the default value from prompt template.", description="User-provided prompt for the query. If provided, this will be used instead of the default value from prompt template.",

View File

@@ -95,7 +95,7 @@ async def llama_index_complete_if_cache(
prompt: str, prompt: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
history_messages: List[dict] = [], history_messages: List[dict] = [],
chat_kwargs={}, **kwargs,
) -> str: ) -> str:
"""Complete the prompt using LlamaIndex.""" """Complete the prompt using LlamaIndex."""
try: try:
@@ -122,9 +122,13 @@ async def llama_index_complete_if_cache(
# Add current prompt # Add current prompt
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt)) formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
response: ChatResponse = await model.achat( # Get LLM instance from kwargs
messages=formatted_messages, **chat_kwargs if "llm_instance" not in kwargs:
) raise ValueError("llm_instance must be provided in kwargs")
llm = kwargs["llm_instance"]
# Get response
response: ChatResponse = await llm.achat(messages=formatted_messages)
# In newer versions, the response is in message.content # In newer versions, the response is in message.content
content = response.message.content content = response.message.content