diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index 1a3ff144..56642185 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -11,6 +11,7 @@ net = Network(height="100vh", notebook=True)
 
 # Convert NetworkX graph to Pyvis network
 net.from_nx(G)
+
 # Add colors and title to nodes
 for node in net.nodes:
     node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
diff --git a/examples/lightrag_api_ollama_demo.py b/examples/lightrag_api_ollama_demo.py
new file mode 100644
index 00000000..36df1262
--- /dev/null
+++ b/examples/lightrag_api_ollama_demo.py
@@ -0,0 +1,164 @@
+from fastapi import FastAPI, HTTPException, File, UploadFile
+from pydantic import BaseModel
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.utils import EmbeddingFunc
+from typing import Optional
+import asyncio
+import nest_asyncio
+import aiofiles
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+app = FastAPI(title="LightRAG API", description="API for RAG operations")
+
+DEFAULT_INPUT_FILE = "book.txt"
+INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
+print(f"INPUT_FILE: {INPUT_FILE}")
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
+print(f"WORKING_DIR: {WORKING_DIR}")
+
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="gemma2:9b",
+    llm_model_max_async=4,
+    llm_model_max_token_size=8192,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 8192}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=768,
+        max_token_size=8192,
+        func=lambda texts: ollama_embedding(
+            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
+        ),
+    ),
+)
+
+
+# Data models
+class QueryRequest(BaseModel):
+    query: str
+    mode: str = "hybrid"
+    only_need_context: bool = False
+
+
+class InsertRequest(BaseModel):
+    text: str
+
+
+class Response(BaseModel):
+    status: str
+    data: Optional[str] = None
+    message: Optional[str] = None
+
+
+# API routes
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+    try:
+        loop = asyncio.get_event_loop()
+        result = await loop.run_in_executor(
+            None,
+            lambda: rag.query(
+                request.query,
+                param=QueryParam(
+                    mode=request.mode, only_need_context=request.only_need_context
+                ),
+            ),
+        )
+        return Response(status="success", data=result)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by text
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+    try:
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(request.text))
+        return Response(status="success", message="Text inserted successfully")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by file in payload
+@app.post("/insert_file", response_model=Response)
+async def insert_file(file: UploadFile = File(...)):
+    try:
+        file_content = await file.read()
+        # Read file content
+        try:
+            content = file_content.decode("utf-8")
+        except UnicodeDecodeError:
+            # If UTF-8 decoding fails, try other encodings
+            content = file_content.decode("gbk")
+        # Insert file content
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(content))
+
+        return Response(
+            status="success",
+            message=f"File content from {file.filename} inserted successfully",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by local default file
+@app.post("/insert_default_file", response_model=Response)
+@app.get("/insert_default_file", response_model=Response)
+async def insert_default_file():
+    try:
+        # Read file content from book.txt
+        async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
+            content = await file.read()
+        print(f"read input file {INPUT_FILE} successfully")
+        # Insert file content
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(content))
+
+        return Response(
+            status="success",
+            message=f"File content from {INPUT_FILE} inserted successfully",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_ollama_demo.py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file (multipart upload):
+# curl -X POST "http://127.0.0.1:8020/insert_file" -F "file=@path/to/your/file.txt"
+
+# 4. Health check:
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 8d887c00..568fce04 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -632,7 +632,7 @@ async def jina_embedding(
     url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
     headers = {
         "Content-Type": "application/json",
-        "Authorization": f"""Bearer {os.environ["JINA_API_KEY"]}""",
+        "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
     }
     data = {
         "model": "jina-embeddings-v3",
diff --git a/lightrag/operate.py b/lightrag/operate.py
index feaec27d..468f4b2f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -222,7 +222,7 @@ async def _merge_edges_then_upsert(
         },
     )
     description = await _handle_entity_relation_summary(
-        (src_id, tgt_id), description, global_config
+        f"({src_id}, {tgt_id})", description, global_config
     )
     await knowledge_graph_inst.upsert_edge(
         src_id,
@@ -572,7 +572,6 @@ async def kg_query(
             mode=query_param.mode,
         ),
     )
-
     return response
 
 
@@ -990,23 +989,37 @@ async def _find_related_text_unit_from_relationships(
     for index, unit_list in enumerate(text_units):
         for c_id in unit_list:
             if c_id not in all_text_units_lookup:
-                all_text_units_lookup[c_id] = {
-                    "data": await text_chunks_db.get_by_id(c_id),
-                    "order": index,
-                }
+                chunk_data = await text_chunks_db.get_by_id(c_id)
+                # Only store valid data
+                if chunk_data is not None and "content" in chunk_data:
+                    all_text_units_lookup[c_id] = {
+                        "data": chunk_data,
+                        "order": index,
+                    }
 
-    if any([v is None for v in all_text_units_lookup.values()]):
-        logger.warning("Text chunks are missing, maybe the storage is damaged")
-    all_text_units = [
-        {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
-    ]
+    if not all_text_units_lookup:
+        logger.warning("No valid text chunks found")
+        return []
+
+    all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
     all_text_units = sorted(all_text_units, key=lambda x: x["order"])
-    all_text_units = truncate_list_by_token_size(
-        all_text_units,
+
+    # Ensure all text chunks have content
+    valid_text_units = [
+        t for t in all_text_units if t["data"] is not None and "content" in t["data"]
+    ]
+
+    if not valid_text_units:
+        logger.warning("No valid text chunks after filtering")
+        return []
+
+    truncated_text_units = truncate_list_by_token_size(
+        valid_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
     )
-    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
+
+    all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
 
     return all_text_units
 
@@ -1050,24 +1063,43 @@ async def naive_query(
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
+
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
+
+    # Filter out invalid chunks
+    valid_chunks = [
+        chunk for chunk in chunks if chunk is not None and "content" in chunk
+    ]
+
+    if not valid_chunks:
+        logger.warning("No valid chunks found after filtering")
+        return PROMPTS["fail_response"]
+
     maybe_trun_chunks = truncate_list_by_token_size(
-        chunks,
+        valid_chunks,
         key=lambda x: x["content"],
         max_token_size=query_param.max_token_for_text_unit,
     )
+
+    if not maybe_trun_chunks:
+        logger.warning("No chunks left after truncation")
+        return PROMPTS["fail_response"]
+
     logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
     section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
+
     if query_param.only_need_context:
         return section
+
     sys_prompt_temp = PROMPTS["naive_rag_response"]
     sys_prompt = sys_prompt_temp.format(
         content_data=section, response_type=query_param.response_type
     )
+
+    if query_param.only_need_prompt:
+        return sys_prompt
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
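For reference, a minimal Python client sketch for the endpoints added in examples/lightrag_api_ollama_demo.py. It assumes the demo server is running locally on port 8020 and uses the third-party requests package; the query text and file name below are placeholders.

import requests

BASE_URL = "http://127.0.0.1:8020"

# Query the index in hybrid mode (matches the curl example in the demo file)
resp = requests.post(
    f"{BASE_URL}/query",
    json={"query": "your query here", "mode": "hybrid"},
)
print(resp.json())

# Insert raw text
requests.post(f"{BASE_URL}/insert", json={"text": "your text here"})

# Upload a file as multipart form data (the form field must be named "file")
with open("book.txt", "rb") as f:
    requests.post(f"{BASE_URL}/insert_file", files={"file": f})

# Health check
print(requests.get(f"{BASE_URL}/health").json())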