Merge branch 'main' into light-webui
@@ -1 +1,63 @@
-.env
+# Python-related files and directories
+__pycache__
+.cache
+
+# Virtual environment directories
+*.venv
+
+# Env
+env/
+*.env*
+.env_example
+
+# Distribution / build files
+site
+dist/
+build/
+.eggs/
+*.egg-info/
+*.tgz
+*.tar.gz
+
+# Exclude files and folders
+*.yml
+.dockerignore
+Dockerfile
+Makefile
+
+# Exclude other projects
+/tests
+/scripts
+
+# Python version manager file
+.python-version
+
+# Reports
+*.coverage/
+*.log
+log/
+*.logfire
+
+# Cache
+.cache/
+.mypy_cache
+.pytest_cache
+.ruff_cache
+.gradio
+.logfire
+temp/
+
+# MacOS-related files
+.DS_Store
+
+# VS Code settings (local configuration files)
+.vscode
+
+# TODO file
+TODO.md
+
+# Exclude Git-related files
+.git
+.github
+.gitignore
+.pre-commit-config.yaml
.env.example
@@ -1,19 +1,20 @@
 ### Server Configuration
-#HOST=0.0.0.0
-#PORT=9621
-#NAMESPACE_PREFIX=lightrag  # separating data from different LightRAG instances
+# HOST=0.0.0.0
+# PORT=9621
+# NAMESPACE_PREFIX=lightrag  # separating data from different LightRAG instances
+# CORS_ORIGINS=http://localhost:3000,http://localhost:8080

 ### Optional SSL Configuration
-#SSL=true
-#SSL_CERTFILE=/path/to/cert.pem
-#SSL_KEYFILE=/path/to/key.pem
+# SSL=true
+# SSL_CERTFILE=/path/to/cert.pem
+# SSL_KEYFILE=/path/to/key.pem

 ### Security (leave empty if no API key is needed)
 # LIGHTRAG_API_KEY=your-secure-api-key-here

 ### Directory Configuration
-# WORKING_DIR=./rag_storage
-# INPUT_DIR=./inputs
+# WORKING_DIR=<absolute_path_for_working_dir>
+# INPUT_DIR=<absolute_path_for_doc_input_dir>

 ### Logging level
 LOG_LEVEL=INFO
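For reference, a minimal sketch of how these settings are consumed — the API server later in this diff imports `load_dotenv`, so commented-out keys simply fall back to in-code defaults:

```python
import os
from dotenv import load_dotenv  # python-dotenv, imported by the API server below

load_dotenv()  # reads .env from the working directory

host = os.getenv("HOST", "0.0.0.0")            # commented-out keys fall back to defaults
port = int(os.getenv("PORT", "9621"))
api_key = os.getenv("LIGHTRAG_API_KEY")        # None -> API key authentication disabled
cors_origins = os.getenv("CORS_ORIGINS", "*")  # comma-separated origin list, or "*"
```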
.gitignore
@@ -1,26 +1,61 @@
-__pycache__
-*.egg-info
+# Python-related files
+__pycache__/
+*.py[cod]
+*.egg-info/
+.eggs/
+*.tgz
+*.tar.gz
+*.ini  # Remove config.ini from repo
+
+# Virtual Environment
+.venv/
+env/
+venv/
+*.env*
+.env_example
+
+# Build / Distribution
+dist/
+build/
+site/
+
+# Logs / Reports
+*.log
+*.logfire
+*.coverage/
+log/
+
+# Caches
+.cache/
+.mypy_cache/
+.pytest_cache/
+.ruff_cache/
+.gradio/
+temp/
+
+# IDE / Editor Files
+.idea/
+.vscode/
+.vscode/settings.json
+
+# Framework-specific files
+local_neo4jWorkDir/
+neo4jWorkDir/
+
+# Data & Storage
+inputs/
+rag_storage/
+examples/input/
+examples/output/
+
+# Miscellaneous
+.DS_Store
+TODO.md
+ignore_this.txt
+*.ignore.*
+
+# Project-specific files
 dickens/
 book.txt
 lightrag-dev/
-.idea/
-dist/
-env/
-local_neo4jWorkDir/
-neo4jWorkDir/
-ignore_this.txt
-.venv/
-*.ignore.*
-.ruff_cache/
 gui/
-*.log
-.vscode
-inputs
-rag_storage
-.env
-venv/
-examples/input/
-examples/output/
-.DS_Store
-#Remove config.ini from repo
-*.ini
@@ -237,7 +237,7 @@ rag = LightRAG(

 * If you want to use Hugging Face models, you only need to set LightRAG as follows:
 ```python
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm import hf_model_complete, hf_embed
 from transformers import AutoModel, AutoTokenizer
 from lightrag.utils import EmbeddingFunc

@@ -250,7 +250,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
             texts,
             tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
             embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@@ -428,9 +428,9 @@ And using a routine to process news documents.

 ```python
 rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
 # Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
 ```

 ### Separate Keyword Extraction
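A minimal sketch of the two-step pipeline shown above, where `fetch_new_docs` is a hypothetical callable returning new document strings — enqueueing and processing are deliberately separate so a routine can batch documents and drain the queue on its own schedule:

```python
import asyncio
from lightrag import LightRAG

async def ingest_loop(rag: LightRAG, fetch_new_docs):
    while True:
        docs = fetch_new_docs()  # hypothetical source of new document strings
        if docs:
            # Stage 1: queue the raw documents
            await rag.apipeline_enqueue_documents(docs)
            # Stage 2: chunk, embed, and index everything waiting in the queue
            await rag.apipeline_process_enqueue_documents()
        await asyncio.sleep(60)  # poll every minute
```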
@@ -113,7 +113,24 @@ async def main():
     )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
-    rag.set_storage_client(db_client=oracle_db)
+    for storage in [
+        rag.vector_db_storage_cls,
+        rag.graph_storage_cls,
+        rag.doc_status,
+        rag.full_docs,
+        rag.text_chunks,
+        rag.llm_response_cache,
+        rag.key_string_value_json_storage_cls,
+        rag.chunks_vdb,
+        rag.relationships_vdb,
+        rag.entities_vdb,
+        rag.graph_storage_cls,
+        rag.chunk_entity_relation_graph,
+        rag.llm_response_cache,
+    ]:
+        # set client
+        storage.db = oracle_db

     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

 # ChromaDB Configuration
+CHROMADB_USE_LOCAL_PERSISTENT = False
+# Local PersistentClient Configuration
+CHROMADB_LOCAL_PATH = os.environ.get(
+    "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
+)
+# Remote HttpClient Configuration
 CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
 CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
 CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")

@@ -60,7 +66,27 @@ async def create_embedding_function_instance():

 async def initialize_rag():
     embedding_func_instance = await create_embedding_function_instance()
+    if CHROMADB_USE_LOCAL_PERSISTENT:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "local_path": CHROMADB_LOCAL_PATH,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
+            },
+        )
+    else:
         return LightRAG(
             working_dir=WORKING_DIR,
             llm_model_func=gpt_4o_mini_complete,
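For reference, a minimal sketch of what those `collection_settings` correspond to on the ChromaDB side — the `hnsw:*` keys are Chroma collection metadata (the collection name here is hypothetical):

```python
import chromadb

# Remote server, matching CHROMADB_HOST / CHROMADB_PORT above
client = chromadb.HttpClient(host="localhost", port=8000)

# HNSW tuning is passed as collection metadata
collection = client.get_or_create_collection(
    name="lightrag_demo",  # hypothetical name
    metadata={"hnsw:space": "cosine", "hnsw:M": 16, "hnsw:construction_ef": 128},
)
```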
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
-- Twitter: @ParisNeo_AI
-- Reddit: r/lollms
-- Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
-- Python 3.8+
-- OpenWebui
-- Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, description="Unknown State", status="in_progress", done=False):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "status": status,
-                        "description": description,
-                        "done": done,
-                    },
-                }
-            )
-
-
-class MessageEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, content="Some message"):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "message",
-                    "data": {
-                        "content": content,
-                    },
-                }
-            )
-
-
-class Tools:
-    class Valves(BaseModel):
-        LIGHTRAG_SERVER_URL: str = Field(
-            default="http://localhost:9621/query",
-            description="The base URL for the LightRag server",
-        )
-        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
-            default="hybrid",
-            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
-        )
-        ONLY_NEED_CONTEXT: bool = Field(
-            default=False,
-            description="If True, only the context is needed from the LightRag response",
-        )
-        DEBUG_MODE: bool = Field(
-            default=False,
-            description="If True, debugging information will be emitted",
-        )
-        KEY: str = Field(
-            default="",
-            description="Optional Bearer Key for authentication",
-        )
-        MAX_ENTITIES: int = Field(
-            default=5,
-            description="Maximum number of entities to keep",
-        )
-        MAX_RELATIONSHIPS: int = Field(
-            default=5,
-            description="Maximum number of relationships to keep",
-        )
-        MAX_SOURCES: int = Field(
-            default=3,
-            description="Maximum number of sources to keep",
-        )
-
-    def __init__(self):
-        self.valves = self.Valves()
-        self.headers = {
-            "Content-Type": "application/json",
-            "User-Agent": "LightRag-Tool/1.0",
-        }
-
-    async def query_lightrag(
-        self,
-        query: str,
-        __event_emitter__: Callable[[dict], Any] = None,
-    ) -> str:
-        """
-        Query the LightRag server and retrieve information.
-        This function must be called before answering the user question
-        :params query: The query string to send to the LightRag server.
-        :return: The response from the LightRag server in Markdown format or raw response.
-        """
-        self.status_emitter = StatusEventEmitter(__event_emitter__)
-        self.message_emitter = MessageEventEmitter(__event_emitter__)
-
-        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
-        payload = {
-            "query": query,
-            "mode": str(self.valves.MODE),
-            "stream": False,
-            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
-        }
-        await self.status_emitter.emit("Initializing Lightrag query..")
-
-        if self.valves.DEBUG_MODE:
-            await self.message_emitter.emit(
-                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
-            )
-            await self.message_emitter.emit(
-                "#### Payload Sent to LightRag Server\n```json\n"
-                + json.dumps(payload, indent=4)
-                + "\n```\n"
-            )
-
-        # Add Bearer Key to headers if provided
-        if self.valves.KEY:
-            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
-        try:
-            await self.status_emitter.emit("Sending request to LightRag server")
-
-            response = requests.post(
-                lightrag_url, json=payload, headers=self.headers, timeout=120
-            )
-            response.raise_for_status()
-            data = response.json()
-            await self.status_emitter.emit(
-                status="complete",
-                description="LightRag query Succeeded",
-                done=True,
-            )
-
-            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
-            if self.valves.ONLY_NEED_CONTEXT:
-                try:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### LightRag Server Response\n```json\n"
-                            + data["response"]
-                            + "\n```\n"
-                        )
-                except Exception as ex:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### Exception\n" + str(ex) + "\n"
-                        )
-                    return f"Exception: {ex}"
-                return data["response"]
-            else:
-                if self.valves.DEBUG_MODE:
-                    await self.message_emitter.emit(
-                        "#### LightRag Server Response\n```json\n"
-                        + data["response"]
-                        + "\n```\n"
-                    )
-                await self.status_emitter.emit("Lightrag query success")
-                return data["response"]
-
-        except requests.exceptions.RequestException as e:
-            await self.status_emitter.emit(
-                status="error",
-                description=f"Error during LightRag query: {str(e)}",
-                done=True,
-            )
-            return json.dumps({"error": str(e)})
-
-    def extract_code_blocks(
-        self, text: str, return_remaining_text: bool = False
-    ) -> Union[List[dict], Tuple[List[dict], str]]:
-        """
-        This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
-        Parameters:
-        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
-        return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
-        Returns:
-        Union[List[dict], Tuple[List[dict], str]]:
-        - If return_remaining_text is False: Returns only the list of code block dictionaries
-        - If return_remaining_text is True: Returns a tuple containing:
-            * List of code block dictionaries
-            * String containing the text with all code blocks removed
-
-        Each code block dictionary contains:
-        - 'index' (int): The index of the code block in the text
-        - 'file_name' (str): The name of the file extracted from the preceding line, if available
-        - 'content' (str): The content of the code block
-        - 'type' (str): The type of the code block
-        - 'is_complete' (bool): True if the block has a closing tag, False otherwise
-        """
-        remaining = text
-        bloc_index = 0
-        first_index = 0
-        indices = []
-        text_without_blocks = text
-
-        # Find all code block delimiters
-        while len(remaining) > 0:
-            try:
-                index = remaining.index("```")
-                indices.append(index + first_index)
-                remaining = remaining[index + 3 :]
-                first_index += index + 3
-                bloc_index += 1
-            except Exception:
-                if bloc_index % 2 == 1:
-                    index = len(remaining)
-                    indices.append(index)
-                remaining = ""
-
-        code_blocks = []
-        is_start = True
-
-        # Process code blocks and build text without blocks if requested
-        if return_remaining_text:
-            text_parts = []
-            last_end = 0
-
-        for index, code_delimiter_position in enumerate(indices):
-            if is_start:
-                block_infos = {
-                    "index": len(code_blocks),
-                    "file_name": "",
-                    "section": "",
-                    "content": "",
-                    "type": "",
-                    "is_complete": False,
-                }
-
-                # Store text before code block if returning remaining text
-                if return_remaining_text:
-                    text_parts.append(text[last_end:code_delimiter_position].strip())
-
-                # Check the preceding line for file name
-                preceding_text = text[:code_delimiter_position].strip().splitlines()
-                if preceding_text:
-                    last_line = preceding_text[-1].strip()
-                    if last_line.startswith("<file_name>") and last_line.endswith(
-                        "</file_name>"
-                    ):
-                        file_name = last_line[
-                            len("<file_name>") : -len("</file_name>")
-                        ].strip()
-                        block_infos["file_name"] = file_name
-                    elif last_line.startswith("## filename:"):
-                        file_name = last_line[len("## filename:") :].strip()
-                        block_infos["file_name"] = file_name
-                    if last_line.startswith("<section>") and last_line.endswith(
-                        "</section>"
-                    ):
-                        section = last_line[
-                            len("<section>") : -len("</section>")
-                        ].strip()
-                        block_infos["section"] = section
-
-                sub_text = text[code_delimiter_position + 3 :]
-                if len(sub_text) > 0:
-                    try:
-                        find_space = sub_text.index(" ")
-                    except Exception:
-                        find_space = int(1e10)
-                    try:
-                        find_return = sub_text.index("\n")
-                    except Exception:
-                        find_return = int(1e10)
-                    next_index = min(find_return, find_space)
-                    if "{" in sub_text[:next_index]:
-                        next_index = 0
-                    start_pos = next_index
-
-                    if code_delimiter_position + 3 < len(text) and text[
-                        code_delimiter_position + 3
-                    ] in ["\n", " ", "\t"]:
-                        block_infos["type"] = "language-specific"
-                    else:
-                        block_infos["type"] = sub_text[:next_index]
-
-                    if index + 1 < len(indices):
-                        next_pos = indices[index + 1] - code_delimiter_position
-                        if (
-                            next_pos - 3 < len(sub_text)
-                            and sub_text[next_pos - 3] == "`"
-                        ):
-                            block_infos["content"] = sub_text[
-                                start_pos : next_pos - 3
-                            ].strip()
-                            block_infos["is_complete"] = True
-                        else:
-                            block_infos["content"] = sub_text[
-                                start_pos:next_pos
-                            ].strip()
-                            block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = indices[index + 1] + 3
-                    else:
-                        block_infos["content"] = sub_text[start_pos:].strip()
-                        block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = len(text)
-
-                    code_blocks.append(block_infos)
-                is_start = False
-            else:
-                is_start = True
-
-        if return_remaining_text:
-            # Add any remaining text after the last code block
-            if last_end < len(text):
-                text_parts.append(text[last_end:].strip())
-            # Join all non-code parts with newlines
-            text_without_blocks = "\n".join(filter(None, text_parts))
-            return code_blocks, text_without_blocks
-
-        return code_blocks
-
-    def clean(self, csv_content: str):
-        lines = csv_content.splitlines()
-        if lines:
-            # Remove spaces around headers and ensure no spaces between commas
-            header = ",".join([col.strip() for col in lines[0].split(",")])
-            lines[0] = header  # Replace the first line with the cleaned header
-        csv_content = "\n".join(lines)
-        return csv_content
@@ -74,30 +74,38 @@ LLM_MODEL=model_name_of_azure_ai
 LLM_BINDING_API_KEY=api_key_of_azure_ai
 ```

-### About Ollama API
-
-We provide an Ollama-compatible interface for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends that support Ollama, such as Open WebUI, to access LightRAG easily.
-
-#### Choose Query mode in chat
-
-A query prefix in the query string determines which LightRAG query mode is used to generate the response for the query. The supported prefixes include:
-
-```
-/local
-/global
-/hybrid
-/naive
-/mix
-/bypass
-```
-
-For example, the chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query in LightRAG. A chat message without a query prefix will trigger a hybrid mode query by default.
-
-"/bypass" is not a LightRAG query mode; it tells the API Server to pass the query directly to the underlying LLM together with the chat history, so the user can have the LLM answer questions based on the LightRAG query results. (If you are using Open WebUI as the front end, you can simply switch the model to a normal LLM instead of using the /bypass prefix.)
-
-#### Connect Open WebUI to LightRAG
-
-After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.
+### 3. Install LightRAG as a Linux Service
+
+Create your service file `lightrag.service` from the sample file `lightrag.service.example`. Modify the `WorkingDirectory` and `ExecStart` entries in the service file:
+
+```text
+Description=LightRAG Ollama Service
+WorkingDirectory=<lightrag installed directory>
+ExecStart=<lightrag installed directory>/lightrag/api/lightrag-api
+```
+
+Modify your service startup script `lightrag-api`. Change the Python virtual environment activation command as needed:
+
+```shell
+#!/bin/bash
+
+# your python virtual environment activation
+source /home/netman/lightrag-xyj/venv/bin/activate
+# start lightrag api server
+lightrag-server
+```
+
+Install the LightRAG service. If your system is Ubuntu, the following commands will work:
+
+```shell
+sudo cp lightrag.service /etc/systemd/system/
+sudo systemctl daemon-reload
+sudo systemctl start lightrag.service
+sudo systemctl status lightrag.service
+sudo systemctl enable lightrag.service
+```

 ## Configuration
@@ -177,7 +185,8 @@ TiDBVectorDBStorage    TiDB
 PGVectorStorage         Postgres
 FaissVectorDBStorage    Faiss
 QdrantVectorDBStorage   Qdrant
-OracleVectorDBStorag    Oracle
+OracleVectorDBStorage   Oracle
+MongoVectorDBStorage    MongoDB
 ```

 * DOC_STATUS_STORAGE: supported implement-name
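A minimal sketch of selecting one of these implementations by name — the `vector_storage` parameter follows the ChromaDB example earlier in this diff; storage-specific connection settings would go through `vector_db_storage_cls_kwargs`:

```python
from lightrag import LightRAG

# vector_storage selects the backend by its implement-name from the table above
rag = LightRAG(
    working_dir="./rag_storage",
    vector_storage="MongoVectorDBStorage",
)
```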
@@ -378,7 +387,7 @@ curl -X DELETE "http://localhost:9621/documents"

 #### GET /api/version

-Get Ollama version information
+Get Ollama version information.

 ```bash
 curl http://localhost:9621/api/version
@@ -386,7 +395,7 @@ curl http://localhost:9621/api/version

 #### GET /api/tags

-Get Ollama available models
+Get Ollama available models.

 ```bash
 curl http://localhost:9621/api/tags
@@ -394,7 +403,7 @@ curl http://localhost:9621/api/tags

 #### POST /api/chat

-Handle chat completion requests
+Handle chat completion requests. Routes user queries through LightRAG by selecting a query mode based on the query prefix, and detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the underlying LLM.

 ```shell
 curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \
@@ -403,6 +412,10 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \

 > For more information about the Ollama API, please visit: [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)

+#### POST /api/generate
+
+Handle generate completion requests. For compatibility purposes, the request is not processed by LightRAG and is handled directly by the underlying LLM.
+
 ### Utility Endpoints

 #### GET /health
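A minimal sketch of calling the generate endpoint; the payload shape is assumed to follow the standard Ollama generate API, and the model name comes from the emulation section below:

```python
import requests

resp = requests.post(
    "http://localhost:9621/api/generate",
    json={
        "model": "lightrag:latest",  # the emulated model name
        "prompt": "Summarize the indexed corpus in one sentence.",
        "stream": False,
    },
    timeout=120,
)
print(resp.json())
```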
@@ -412,7 +425,35 @@ Check server health and configuration.
 curl "http://localhost:9621/health"
 ```

+## Ollama Emulation
+
+We provide an Ollama-compatible interface for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends that support Ollama, such as Open WebUI, to access LightRAG easily.
+
+### Connect Open WebUI to LightRAG
+
+After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface. Installing LightRAG as a service is recommended for this use case.
+
+Open WebUI uses an LLM for its session title and session keyword generation tasks, so the Ollama chat completion API detects OpenWebUI session-related requests and forwards them directly to the underlying LLM.
+
+### Choose Query mode in chat
+
+A query prefix in the query string determines which LightRAG query mode is used to generate the response. The supported prefixes include:
+
+```
+/local
+/global
+/hybrid
+/naive
+/mix
+/bypass
+```
+
+For example, the chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query in LightRAG. A chat message without a query prefix will trigger a hybrid mode query by default.
+
+"/bypass" is not a LightRAG query mode; it tells the API Server to pass the query directly to the underlying LLM together with the chat history, so the user can have the LLM answer questions based on the chat history. If you are using Open WebUI as the front end, you can simply switch the model to a normal LLM instead of using the /bypass prefix.
+
 ## Development

 Contribute to the project: [Guide](contributor-readme.MD)

 ### Running in Development Mode
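A minimal sketch of a prefixed chat query against the emulated endpoint; the message and response shapes follow the standard Ollama chat API:

```python
import requests

resp = requests.post(
    "http://localhost:9621/api/chat",
    json={
        "model": "lightrag:latest",
        "messages": [
            # "/mix" selects the mix query mode; omit the prefix for hybrid
            {"role": "user", "content": "/mix How many disciples does Tang Seng have?"}
        ],
        "stream": False,
    },
    timeout=120,
)
print(resp.json()["message"]["content"])
```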
@@ -470,34 +511,3 @@ This intelligent caching mechanism:
 - Only new documents in the input directory will be processed
 - This optimization significantly reduces startup time for subsequent runs
 - The working directory (`--working-dir`) stores the vectorized documents database
-
-## Install LightRAG as a Linux Service
-
-Create your service file `lightrag.service` from the sample file `lightrag.service.example`. Modify the `WorkingDirectory` and `ExecStart` entries in the service file:
-
-```text
-Description=LightRAG Ollama Service
-WorkingDirectory=<lightrag installed directory>
-ExecStart=<lightrag installed directory>/lightrag/api/lightrag-api
-```
-
-Modify your service startup script `lightrag-api`. Change the Python virtual environment activation command as needed:
-
-```shell
-#!/bin/bash
-
-# your python virtual environment activation
-source /home/netman/lightrag-xyj/venv/bin/activate
-# start lightrag api server
-lightrag-server
-```
-
-Install the LightRAG service. If your system is Ubuntu, the following commands will work:
-
-```shell
-sudo cp lightrag.service /etc/systemd/system/
-sudo systemctl daemon-reload
-sudo systemctl start lightrag.service
-sudo systemctl status lightrag.service
-sudo systemctl enable lightrag.service
-```
@@ -3,7 +3,6 @@ from fastapi import (
     HTTPException,
     File,
     UploadFile,
-    Form,
    BackgroundTasks,
 )
 import asyncio
@@ -14,7 +13,7 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Union, Dict
+from typing import List, Any, Optional, Dict
 from pydantic import BaseModel
 from lightrag import LightRAG, QueryParam
 from lightrag.types import GPTKeywordExtractionFormat

@@ -34,6 +33,9 @@ from starlette.status import HTTP_403_FORBIDDEN
 import pipmaster as pm
 from dotenv import load_dotenv
 import configparser
+import traceback
+from datetime import datetime

 from lightrag.utils import logger
 from .ollama_api import (
     OllamaAPI,
@@ -159,8 +161,12 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.host}")
     ASCIIColors.white(" ├─ Port: ", end="")
     ASCIIColors.yellow(f"{args.port}")
-    ASCIIColors.white(" └─ SSL Enabled: ", end="")
+    ASCIIColors.white(" ├─ CORS Origins: ", end="")
+    ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
+    ASCIIColors.white(" ├─ SSL Enabled: ", end="")
     ASCIIColors.yellow(f"{args.ssl}")
+    ASCIIColors.white(" └─ API Key: ", end="")
+    ASCIIColors.yellow("Set" if args.key else "Not Set")
     if args.ssl:
         ASCIIColors.white(" ├─ SSL Cert: ", end="")
         ASCIIColors.yellow(f"{args.ssl_certfile}")

@@ -229,10 +235,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
     ASCIIColors.white(" ├─ Log Level: ", end="")
     ASCIIColors.yellow(f"{args.log_level}")
-    ASCIIColors.white(" ├─ Timeout: ", end="")
+    ASCIIColors.white(" └─ Timeout: ", end="")
     ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
-    ASCIIColors.white(" └─ API Key: ", end="")
-    ASCIIColors.yellow("Set" if args.key else "Not Set")

     # Server Status
     ASCIIColors.green("\n✨ Server starting up...\n")
@@ -564,6 +568,10 @@ def parse_args() -> argparse.Namespace:

     args = parser.parse_args()

+    # convert relative path to absolute path
+    args.working_dir = os.path.abspath(args.working_dir)
+    args.input_dir = os.path.abspath(args.input_dir)
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

     return args
@@ -595,6 +603,7 @@ class DocumentManager:
         """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
+            logger.info(f"Scanning for {ext} files in {self.input_dir}")
             for file_path in self.input_dir.rglob(f"*{ext}"):
                 if file_path not in self.indexed_files:
                     new_files.append(file_path)
@@ -628,9 +637,47 @@ class SearchMode(str, Enum):

 class QueryRequest(BaseModel):
     query: str

+    """Specifies the retrieval mode"""
     mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
+
+    """If True, enables streaming output for real-time responses."""
+    stream: Optional[bool] = None
+
+    """If True, only returns the retrieved context without generating a response."""
+    only_need_context: Optional[bool] = None
+
+    """If True, only returns the generated prompt without producing a response."""
+    only_need_prompt: Optional[bool] = None
+
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+    response_type: Optional[str] = None
+
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+    top_k: Optional[int] = None
+
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+    max_token_for_text_unit: Optional[int] = None
+
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+    max_token_for_global_context: Optional[int] = None
+
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    max_token_for_local_context: Optional[int] = None
+
+    """List of high-level keywords to prioritize in retrieval."""
+    hl_keywords: Optional[List[str]] = None
+
+    """List of low-level keywords to refine retrieval focus."""
+    ll_keywords: Optional[List[str]] = None
+
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+    conversation_history: Optional[List[dict[str, Any]]] = None
+
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
+    history_turns: Optional[int] = None


 class QueryResponse(BaseModel):
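A minimal sketch of a request body matching this model; the endpoint path comes from the tool's default server URL earlier in this diff, and all optional fields may simply be omitted:

```python
import requests

payload = {
    "query": "What are the main themes of the corpus?",
    "mode": "hybrid",
    "only_need_context": False,
    "top_k": 60,
    "response_type": "Multiple Paragraphs",
    "conversation_history": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! Ask me about the corpus."},
    ],
    "history_turns": 1,
}
resp = requests.post("http://localhost:9621/query", json=payload, timeout=120)
print(resp.json()["response"])
```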
@@ -639,13 +686,38 @@ class QueryResponse(BaseModel):

 class InsertTextRequest(BaseModel):
     text: str
-    description: Optional[str] = None


 class InsertResponse(BaseModel):
     status: str
     message: str
-    document_count: int
+
+
+def QueryRequestToQueryParams(request: QueryRequest):
+    param = QueryParam(mode=request.mode, stream=request.stream)
+    if request.only_need_context is not None:
+        param.only_need_context = request.only_need_context
+    if request.only_need_prompt is not None:
+        param.only_need_prompt = request.only_need_prompt
+    if request.response_type is not None:
+        param.response_type = request.response_type
+    if request.top_k is not None:
+        param.top_k = request.top_k
+    if request.max_token_for_text_unit is not None:
+        param.max_token_for_text_unit = request.max_token_for_text_unit
+    if request.max_token_for_global_context is not None:
+        param.max_token_for_global_context = request.max_token_for_global_context
+    if request.max_token_for_local_context is not None:
+        param.max_token_for_local_context = request.max_token_for_local_context
+    if request.hl_keywords is not None:
+        param.hl_keywords = request.hl_keywords
+    if request.ll_keywords is not None:
+        param.ll_keywords = request.ll_keywords
+    if request.conversation_history is not None:
+        param.conversation_history = request.conversation_history
+    if request.history_turns is not None:
+        param.history_turns = request.history_turns
+    return param


 def get_api_key_dependency(api_key: Optional[str]):
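A minimal usage sketch of the converter above, inside a hypothetical handler body (values are illustrative; the `rag.aquery` call assumes LightRAG's query API):

```python
# Convert the validated request into a QueryParam, then run the query
request = QueryRequest(query="What is LightRAG?", mode=SearchMode.hybrid, top_k=40)
param = QueryRequestToQueryParams(request)  # only explicitly-set fields are overridden
response = await rag.aquery(request.query, param=param)
```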
@@ -659,7 +731,9 @@ def get_api_key_dependency(api_key: Optional[str]):
     # If API key is configured, use proper authentication
     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
+    async def api_key_auth(
+        api_key_header_value: Optional[str] = Security(api_key_header),
+    ):
         if not api_key_header_value:
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -675,6 +749,7 @@ def get_api_key_dependency(api_key: Optional[str]):

 # Global configuration
 global_top_k = 60  # default value
+temp_prefix = "__tmp_"  # prefix for temporary files


 def create_app(args):
@@ -842,10 +917,19 @@ def create_app(args):
         lifespan=lifespan,
     )

+    def get_cors_origins():
+        """Get allowed origins from environment variable
+
+        Returns a list of allowed origins, defaults to ["*"] if not set
+        """
+        origins_str = os.getenv("CORS_ORIGINS", "*")
+        if origins_str == "*":
+            return ["*"]
+        return [origin.strip() for origin in origins_str.split(",")]
+
     # Add CORS middleware
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=["*"],
+        allow_origins=get_cors_origins(),
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
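For example, with a restricted origin list set in the environment (the env var comes from `.env.example` earlier in this diff), `get_cors_origins()` resolves as sketched here:

```python
import os

os.environ["CORS_ORIGINS"] = "http://localhost:3000, http://localhost:8080"

origins_str = os.getenv("CORS_ORIGINS", "*")
origins = [origin.strip() for origin in origins_str.split(",")]
print(origins)  # ['http://localhost:3000', 'http://localhost:8080']
```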
@@ -1116,61 +1200,194 @@ def create_app(args):
             ("llm_response_cache", rag.llm_response_cache),
         ]

-    async def index_file(file_path: Union[str, Path]) -> None:
-        """Index all files inside the folder with support for multiple file formats
-
-        Args:
-            file_path: Path to the file to be indexed (str or Path object)
-
-        Raises:
-            ValueError: If file format is not supported
-            FileNotFoundError: If file doesn't exist
-        """
-        if not pm.is_installed("aiofiles"):
-            pm.install("aiofiles")
-
-        # Convert to Path object if string
-        file_path = Path(file_path)
-
-        # Check if file exists
-        if not file_path.exists():
-            raise FileNotFoundError(f"File not found: {file_path}")
-
-        content = ""
-        # Get file extension in lowercase
-        ext = file_path.suffix.lower()
-
-        match ext:
-            case ".txt" | ".md":
-                # Text files handling
-                async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
-                    content = await f.read()
-
-            case ".pdf" | ".docx" | ".pptx" | ".xlsx":
-                if not pm.is_installed("docling"):
-                    pm.install("docling")
-                from docling.document_converter import DocumentConverter
-
-                async def convert_doc():
-                    def sync_convert():
-                        converter = DocumentConverter()
-                        result = converter.convert(file_path)
-                        return result.document.export_to_markdown()
-
-                    return await asyncio.to_thread(sync_convert)
-
-                content = await convert_doc()
-
-            case _:
-                raise ValueError(f"Unsupported file format: {ext}")
-
-        # Insert content into RAG system
-        if content:
-            await rag.ainsert(content)
-            doc_manager.mark_as_indexed(file_path)
-            logging.info(f"Successfully indexed file: {file_path}")
-        else:
-            logging.warning(f"No content extracted from file: {file_path}")
+    async def pipeline_enqueue_file(file_path: Path) -> bool:
+        """Add a file to the queue for processing
+
+        Args:
+            file_path: Path to the saved file
+        Returns:
+            bool: True if the file was successfully enqueued, False otherwise
+        """
+        try:
+            content = ""
+            ext = file_path.suffix.lower()
+
+            file = None
+            async with aiofiles.open(file_path, "rb") as f:
+                file = await f.read()
+
+            # Process based on file type
+            match ext:
+                case ".txt" | ".md":
+                    content = file.decode("utf-8")
+                case ".pdf":
+                    if not pm.is_installed("pypdf2"):
+                        pm.install("pypdf2")
+                    from PyPDF2 import PdfReader
+                    from io import BytesIO
+
+                    pdf_file = BytesIO(file)
+                    reader = PdfReader(pdf_file)
+                    for page in reader.pages:
+                        content += page.extract_text() + "\n"
+                case ".docx":
+                    if not pm.is_installed("docx"):
+                        pm.install("docx")
+                    from docx import Document
+                    from io import BytesIO
+
+                    # `file` already holds the raw bytes read above
+                    docx_file = BytesIO(file)
+                    doc = Document(docx_file)
+                    content = "\n".join(
+                        [paragraph.text for paragraph in doc.paragraphs]
+                    )
+                case ".pptx":
+                    if not pm.is_installed("pptx"):
+                        pm.install("pptx")
+                    from pptx import Presentation  # type: ignore
+                    from io import BytesIO
+
+                    pptx_file = BytesIO(file)
+                    prs = Presentation(pptx_file)
+                    for slide in prs.slides:
+                        for shape in slide.shapes:
+                            if hasattr(shape, "text"):
+                                content += shape.text + "\n"
+                case _:
+                    logging.error(
+                        f"Unsupported file type: {file_path.name} (extension {ext})"
+                    )
+                    return False
+
+            # Insert into the RAG queue
+            if content:
+                await rag.apipeline_enqueue_documents(content)
+                logging.info(
+                    f"Successfully processed and enqueued file: {file_path.name}"
+                )
+                return True
+            else:
+                logging.error(
+                    f"No content could be extracted from file: {file_path.name}"
+                )
+
+        except Exception as e:
+            logging.error(
+                f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+            )
+            logging.error(traceback.format_exc())
+        finally:
+            if file_path.name.startswith(temp_prefix):
+                # Clean up the temporary file after indexing
+                try:
+                    file_path.unlink()
+                except Exception as e:
+                    logging.error(f"Error deleting file {file_path}: {str(e)}")
+        return False
+
+    async def pipeline_index_file(file_path: Path):
+        """Index a file
+
+        Args:
+            file_path: Path to the saved file
+        """
+        try:
+            if await pipeline_enqueue_file(file_path):
+                await rag.apipeline_process_enqueue_documents()
+
+        except Exception as e:
+            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_files(file_paths: List[Path]):
+        """Index multiple files concurrently
+
+        Args:
+            file_paths: Paths to the files to index
+        """
+        if not file_paths:
+            return
+        try:
+            enqueued = False
+
+            if len(file_paths) == 1:
+                enqueued = await pipeline_enqueue_file(file_paths[0])
+            else:
+                tasks = [pipeline_enqueue_file(path) for path in file_paths]
+                enqueued = any(await asyncio.gather(*tasks))
+
+            if enqueued:
+                await rag.apipeline_process_enqueue_documents()
+        except Exception as e:
+            logging.error(f"Error indexing files: {str(e)}")
+            logging.error(traceback.format_exc())
+
+    async def pipeline_index_texts(texts: List[str]):
+        """Index a list of texts
+
+        Args:
+            texts: The texts to index
+        """
+        if not texts:
+            return
+        await rag.apipeline_enqueue_documents(texts)
+        await rag.apipeline_process_enqueue_documents()
+
+    async def save_temp_file(file: UploadFile = File(...)) -> Path:
+        """Save the uploaded file to a temporary location
+
+        Args:
+            file: The uploaded file
+
+        Returns:
+            Path: The path to the saved file
+        """
+        # Generate unique filename to avoid conflicts
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
+
+        # Create a temporary file to save the uploaded content
+        temp_path = doc_manager.input_dir / "temp" / unique_filename
+        temp_path.parent.mkdir(exist_ok=True)
+
+        # Save the file
+        with open(temp_path, "wb") as buffer:
+            shutil.copyfileobj(file.file, buffer)
+        return temp_path
+
+    async def run_scanning_process():
+        """Background task to scan and index documents"""
+        global scan_progress
+
+        try:
+            new_files = doc_manager.scan_directory_for_new_files()
+            scan_progress["total_files"] = len(new_files)
+
+            logger.info(f"Found {len(new_files)} new files to index.")
+            for file_path in new_files:
+                try:
+                    with progress_lock:
+                        scan_progress["current_file"] = os.path.basename(file_path)
+
+                    await pipeline_index_file(file_path)
+
+                    with progress_lock:
+                        scan_progress["indexed_count"] += 1
+                        scan_progress["progress"] = (
+                            scan_progress["indexed_count"]
+                            / scan_progress["total_files"]
+                        ) * 100
+
+                except Exception as e:
+                    logging.error(f"Error indexing file {file_path}: {str(e)}")
+
+        except Exception as e:
+            logging.error(f"Error during scanning process: {str(e)}")
+        finally:
+            with progress_lock:
+                scan_progress["is_scanning"] = False

     @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
     async def scan_for_new_documents(background_tasks: BackgroundTasks):
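A minimal client-side sketch of the resulting upload-and-scan flow; the endpoints and the scan-progress keys come from this diff, and the `X-API-Key` header is only needed when `LIGHTRAG_API_KEY` is set:

```python
import time

import requests

base = "http://localhost:9621"
headers = {"X-API-Key": "your-secure-api-key-here"}  # omit if no key is configured

# Trigger a background scan of the input directory
requests.post(f"{base}/documents/scan", headers=headers, timeout=30)

# Poll the progress reported by /documents/scan-progress
while True:
    progress = requests.get(f"{base}/documents/scan-progress", timeout=30).json()
    if not progress.get("is_scanning"):
        break
    time.sleep(2)
print("indexed:", progress.get("indexed_count"))
```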
@@ -1190,37 +1407,6 @@ def create_app(args):

         return {"status": "scanning_started"}

-    async def run_scanning_process():
-        """Background task to scan and index documents"""
-        global scan_progress
-
-        try:
-            new_files = doc_manager.scan_directory_for_new_files()
-            scan_progress["total_files"] = len(new_files)
-
-            for file_path in new_files:
-                try:
-                    with progress_lock:
-                        scan_progress["current_file"] = os.path.basename(file_path)
-
-                    await index_file(file_path)
-
-                    with progress_lock:
-                        scan_progress["indexed_count"] += 1
-                        scan_progress["progress"] = (
-                            scan_progress["indexed_count"]
-                            / scan_progress["total_files"]
-                        ) * 100
-
-                except Exception as e:
-                    logging.error(f"Error indexing file {file_path}: {str(e)}")
-
-        except Exception as e:
-            logging.error(f"Error during scanning process: {str(e)}")
-        finally:
-            with progress_lock:
-                scan_progress["is_scanning"] = False
-
     @app.get("/documents/scan-progress")
     async def get_scan_progress():
         """Get the current scanning progress"""
@@ -1228,7 +1414,9 @@ def create_app(args):
         return scan_progress

     @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
-    async def upload_to_input_dir(file: UploadFile = File(...)):
+    async def upload_to_input_dir(
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
+    ):
         """
         Endpoint for uploading a file to the input directory and indexing it.

@@ -1237,6 +1425,7 @@ def create_app(args):
         indexes it for retrieval, and returns a success status with relevant details.

         Parameters:
+            background_tasks: FastAPI BackgroundTasks for async processing
             file (UploadFile): The file to be uploaded. It must have an allowed extension as per
                 `doc_manager.supported_extensions`.
@@ -1261,15 +1450,175 @@ def create_app(args):
             with open(file_path, "wb") as buffer:
                 shutil.copyfileobj(file.file, buffer)

-            # Immediately index the uploaded file
-            await index_file(file_path)
+            # Add to background tasks
+            background_tasks.add_task(pipeline_index_file, file_path)

-            return {
-                "status": "success",
-                "message": f"File uploaded and indexed: {file.filename}",
-                "total_documents": len(doc_manager.indexed_files),
-            }
+            return InsertResponse(
+                status="success",
+                message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
+            )
         except Exception as e:
+            logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/text",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_text(
+        request: InsertTextRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert text into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextRequest): The request body containing the text to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, [request.text])
+            return InsertResponse(
+                status="success",
+                message="Text successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/text: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/file",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_file(
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
+    ):
+        """Insert a file directly into the RAG system
+
+        Args:
+            background_tasks: FastAPI BackgroundTasks for async processing
+            file: Uploaded file
+
+        Returns:
+            InsertResponse: Status of the insertion operation
+
+        Raises:
+            HTTPException: For unsupported file types or processing errors
+        """
+        try:
+            if not doc_manager.is_supported_file(file.filename):
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
+                )
+
+            # Create a temporary file to save the uploaded content
+            temp_path = await save_temp_file(file)
+
+            # Add to background tasks
+            background_tasks.add_task(pipeline_index_file, temp_path)
+
+            return InsertResponse(
+                status="success",
+                message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
+            )
+
+        except Exception as e:
+            logging.error(f"Error /documents/file: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post(
+        "/documents/batch",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_batch(
+        background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
+    ):
+        """Process multiple files in batch mode
+
+        Args:
+            background_tasks: FastAPI BackgroundTasks for async processing
+            files: List of files to process
+
+        Returns:
+            InsertResponse: Status of the batch insertion operation
+
+        Raises:
+            HTTPException: For processing errors
+        """
+        try:
+            inserted_count = 0
+            failed_files = []
+            temp_files = []
+
+            for file in files:
+                if doc_manager.is_supported_file(file.filename):
+                    # Create a temporary file to save the uploaded content
+                    temp_files.append(await save_temp_file(file))
+                    inserted_count += 1
+                else:
+                    failed_files.append(f"{file.filename} (unsupported type)")
+
+            if temp_files:
+                background_tasks.add_task(pipeline_index_files, temp_files)
+
+            # Prepare status message
+            if inserted_count == len(files):
+                status = "success"
+                status_message = f"Successfully inserted all {inserted_count} documents"
+            elif inserted_count > 0:
|
status = "partial_success"
|
||||||
|
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
|
||||||
|
if failed_files:
|
||||||
|
status_message += f". Failed files: {', '.join(failed_files)}"
|
||||||
|
else:
|
||||||
|
status = "failure"
|
||||||
|
status_message = "No documents were successfully inserted"
|
||||||
|
if failed_files:
|
||||||
|
status_message += f". Failed files: {', '.join(failed_files)}"
|
||||||
|
|
||||||
|
return InsertResponse(status=status, message=status_message)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error /documents/batch: {file.filename}: {str(e)}")
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.delete(
|
||||||
|
"/documents",
|
||||||
|
response_model=InsertResponse,
|
||||||
|
dependencies=[Depends(optional_api_key)],
|
||||||
|
)
|
||||||
|
async def clear_documents():
|
||||||
|
"""
|
||||||
|
Clear all documents from the LightRAG system.
|
||||||
|
|
||||||
|
This endpoint deletes all text chunks, entities vector database, and relationships vector database,
|
||||||
|
effectively clearing all documents from the LightRAG system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
rag.text_chunks = []
|
||||||
|
rag.entities_vdb = None
|
||||||
|
rag.relationships_vdb = None
|
||||||
|
return InsertResponse(
|
||||||
|
status="success", message="All documents cleared successfully"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error DELETE /documents: {str(e)}")
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
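The hunks above move all document ingestion into FastAPI background tasks, so a 200 response now means "queued", not "indexed". A minimal client-side sketch of the new flow (assuming a server on http://localhost:9621 and the `httpx` package, neither of which this diff mandates):

import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:9621") as client:
        # Upload a file; the server saves it and schedules pipeline_index_file,
        # then returns immediately while indexing continues in the background.
        with open("sample.txt", "rb") as f:
            r = await client.post("/documents/upload", files={"file": ("sample.txt", f)})
        print(r.json())

        # Raw text goes through the same queue via pipeline_index_texts.
        r = await client.post("/documents/text", json={"text": "LightRAG stores chunks."})
        print(r.json())

asyncio.run(main())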
    @app.post(
@@ -1280,12 +1629,7 @@ def create_app(args):
         Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
 
         Parameters:
-            request (QueryRequest): A Pydantic model containing the following fields:
-                - query (str): The text of the user's query.
-                - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
-                - stream (bool): Optional. Determines if the response should be streamed.
-                - only_need_context (bool): Optional. If true, returns only the context without further processing.
+            request (QueryRequest): The request object containing the query parameters.
 
         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
                 If a string is returned (e.g., cache hit), it's directly returned.
@@ -1297,13 +1641,7 @@ def create_app(args):
         """
         try:
             response = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=request.stream,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=QueryRequestToQueryParams(request)
             )
 
             # If response is a string (e.g. cache hit), return directly
@@ -1311,16 +1649,16 @@ def create_app(args):
                 return QueryResponse(response=response)
 
             # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream:
+            if request.stream or hasattr(response, "__aiter__"):
                 result = ""
                 async for chunk in response:
                     result += chunk
                 return QueryResponse(response=result)
+            elif isinstance(response, dict):
+                result = json.dumps(response, indent=2)
+                return QueryResponse(response=result)
             else:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
+                return QueryResponse(response=str(response))
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
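The inline QueryParam construction is replaced by a `QueryRequestToQueryParams` helper; its body does not appear in this section, but a sketch consistent with how it is called here might look like this (field names mirror the old inline call, and `global_top_k` is assumed to be in scope):

from lightrag.base import QueryParam  # QueryParam is defined in lightrag/base.py

def QueryRequestToQueryParams(request) -> QueryParam:
    # Map the Pydantic request model onto the dataclass the core expects.
    return QueryParam(
        mode=request.mode,
        stream=request.stream,
        only_need_context=request.only_need_context,
        top_k=global_top_k,
    )

Centralizing the mapping keeps /query and /query/stream from drifting apart, which is exactly what the next hunk relies on.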
@@ -1338,14 +1676,11 @@ def create_app(args):
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
+            params = QueryRequestToQueryParams(request)
+
+            params.stream = True
             response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=True,
-                    only_need_context=request.only_need_context,
-                    top_k=global_top_k,
-                ),
+                request.query, param=params
             )
 
             from fastapi.responses import StreamingResponse
@@ -1371,265 +1706,13 @@ def create_app(args):
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "Access-Control-Allow-Origin": "*",
-                    "Access-Control-Allow-Methods": "POST, OPTIONS",
-                    "Access-Control-Allow-Headers": "Content-Type",
-                    "X-Accel-Buffering": "no",  # Disable Nginx buffering
+                    "X-Accel-Buffering": "no",  # Ensure streaming responses are handled correctly behind an Nginx proxy
                 },
             )
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
 
-    @app.post(
-        "/documents/text",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_text(request: InsertTextRequest):
-        """
-        Insert text into the Retrieval-Augmented Generation (RAG) system.
-
-        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
-
-        Args:
-            request (InsertTextRequest): The request body containing the text to be inserted.
-
-        Returns:
-            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
-        """
-        try:
-            await rag.ainsert(request.text)
-            return InsertResponse(
-                status="success",
-                message="Text successfully inserted",
-                document_count=1,
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/documents/file",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
-        """Insert a file directly into the RAG system
-
-        Args:
-            file: Uploaded file
-            description: Optional description of the file
-
-        Returns:
-            InsertResponse: Status of the insertion operation
-
-        Raises:
-            HTTPException: For unsupported file types or processing errors
-        """
-        try:
-            content = ""
-            # Get file extension in lowercase
-            ext = Path(file.filename).suffix.lower()
-
-            match ext:
-                case ".txt" | ".md":
-                    # Text files handling
-                    text_content = await file.read()
-                    content = text_content.decode("utf-8")
-
-                case ".pdf" | ".docx" | ".pptx" | ".xlsx":
-                    if not pm.is_installed("docling"):
-                        pm.install("docling")
-                    from docling.document_converter import DocumentConverter
-
-                    # Create a temporary file to save the uploaded content
-                    temp_path = Path("temp") / file.filename
-                    temp_path.parent.mkdir(exist_ok=True)
-
-                    # Save the uploaded file
-                    with temp_path.open("wb") as f:
-                        f.write(await file.read())
-
-                    try:
-
-                        async def convert_doc():
-                            def sync_convert():
-                                converter = DocumentConverter()
-                                result = converter.convert(str(temp_path))
-                                return result.document.export_to_markdown()
-
-                            return await asyncio.to_thread(sync_convert)
-
-                        content = await convert_doc()
-                    finally:
-                        # Clean up the temporary file
-                        temp_path.unlink()
-
-            # Insert content into RAG system
-            if content:
-                # Add description if provided
-                if description:
-                    content = f"{description}\n\n{content}"
-
-                await rag.ainsert(content)
-                logging.info(f"Successfully indexed file: {file.filename}")
-
-                return InsertResponse(
-                    status="success",
-                    message=f"File '{file.filename}' successfully inserted",
-                    document_count=1,
-                )
-            else:
-                raise HTTPException(
-                    status_code=400,
-                    detail="No content could be extracted from the file",
-                )
-
-        except UnicodeDecodeError:
-            raise HTTPException(status_code=400, detail="File encoding not supported")
-        except Exception as e:
-            logging.error(f"Error processing file {file.filename}: {str(e)}")
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/documents/batch",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_batch(files: List[UploadFile] = File(...)):
-        """Process multiple files in batch mode
-
-        Args:
-            files: List of files to process
-
-        Returns:
-            InsertResponse: Status of the batch insertion operation
-
-        Raises:
-            HTTPException: For processing errors
-        """
-        try:
-            inserted_count = 0
-            failed_files = []
-
-            for file in files:
-                try:
-                    content = ""
-                    ext = Path(file.filename).suffix.lower()
-
-                    match ext:
-                        case ".txt" | ".md":
-                            text_content = await file.read()
-                            content = text_content.decode("utf-8")
-
-                        case ".pdf":
-                            if not pm.is_installed("pypdf2"):
-                                pm.install("pypdf2")
-                            from PyPDF2 import PdfReader
-                            from io import BytesIO
-
-                            pdf_content = await file.read()
-                            pdf_file = BytesIO(pdf_content)
-                            reader = PdfReader(pdf_file)
-                            for page in reader.pages:
-                                content += page.extract_text() + "\n"
-
-                        case ".docx":
-                            if not pm.is_installed("docx"):
-                                pm.install("docx")
-                            from docx import Document
-                            from io import BytesIO
-
-                            docx_content = await file.read()
-                            docx_file = BytesIO(docx_content)
-                            doc = Document(docx_file)
-                            content = "\n".join(
-                                [paragraph.text for paragraph in doc.paragraphs]
-                            )
-
-                        case ".pptx":
-                            if not pm.is_installed("pptx"):
-                                pm.install("pptx")
-                            from pptx import Presentation  # type: ignore
-                            from io import BytesIO
-
-                            pptx_content = await file.read()
-                            pptx_file = BytesIO(pptx_content)
-                            prs = Presentation(pptx_file)
-                            for slide in prs.slides:
-                                for shape in slide.shapes:
-                                    if hasattr(shape, "text"):
-                                        content += shape.text + "\n"
-
-                        case _:
-                            failed_files.append(f"{file.filename} (unsupported type)")
-                            continue
-
-                    if content:
-                        await rag.ainsert(content)
-                        inserted_count += 1
-                        logging.info(f"Successfully indexed file: {file.filename}")
-                    else:
-                        failed_files.append(f"{file.filename} (no content extracted)")
-
-                except UnicodeDecodeError:
-                    failed_files.append(f"{file.filename} (encoding error)")
-                except Exception as e:
-                    failed_files.append(f"{file.filename} ({str(e)})")
-                    logging.error(f"Error processing file {file.filename}: {str(e)}")
-
-            # Prepare status message
-            if inserted_count == len(files):
-                status = "success"
-                status_message = f"Successfully inserted all {inserted_count} documents"
-            elif inserted_count > 0:
-                status = "partial_success"
-                status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
-                if failed_files:
-                    status_message += f". Failed files: {', '.join(failed_files)}"
-            else:
-                status = "failure"
-                status_message = "No documents were successfully inserted"
-                if failed_files:
-                    status_message += f". Failed files: {', '.join(failed_files)}"
-
-            return InsertResponse(
-                status=status,
-                message=status_message,
-                document_count=inserted_count,
-            )
-
-        except Exception as e:
-            logging.error(f"Batch processing error: {str(e)}")
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.delete(
-        "/documents",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def clear_documents():
-        """
-        Clear all documents from the LightRAG system.
-
-        This endpoint deletes all text chunks, entities vector database, and relationships vector database,
-        effectively clearing all documents from the LightRAG system.
-
-        Returns:
-            InsertResponse: A response object containing the status, message, and the new document count (0 in this case).
-        """
-        try:
-            rag.text_chunks = []
-            rag.entities_vdb = None
-            rag.relationships_vdb = None
-            return InsertResponse(
-                status="success",
-                message="All documents cleared successfully",
-                document_count=0,
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
     # query all graph labels
     @app.get("/graph/label/list")
     async def get_graph_labels():
@@ -316,9 +316,7 @@ class OllamaAPI:
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "Access-Control-Allow-Origin": "*",
-                    "Access-Control-Allow-Methods": "POST, OPTIONS",
-                    "Access-Control-Allow-Headers": "Content-Type",
+                    "X-Accel-Buffering": "no",  # Ensure streaming responses are handled correctly behind an Nginx proxy
                 },
             )
         else:
@@ -534,9 +532,7 @@ class OllamaAPI:
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "Access-Control-Allow-Origin": "*",
-                    "Access-Control-Allow-Methods": "POST, OPTIONS",
-                    "Access-Control-Allow-Headers": "Content-Type",
+                    "X-Accel-Buffering": "no",  # Ensure streaming responses are handled correctly behind an Nginx proxy
                 },
             )
         else:
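All of these hunks drop the hand-rolled per-response CORS headers and keep a single `X-Accel-Buffering: no` hint, which tells an Nginx reverse proxy not to buffer the NDJSON stream. A quick way to watch the stream arrive chunk by chunk (the /query/stream path and port are assumptions based on this server's conventions, and `httpx` is not used by this commit):

import asyncio
import httpx

async def stream_query() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:9621", timeout=None) as client:
        async with client.stream(
            "POST", "/query/stream", json={"query": "What is LightRAG?", "mode": "hybrid"}
        ) as response:
            async for line in response.aiter_lines():  # one NDJSON object per line
                print(line)

asyncio.run(stream_query())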
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 import numpy as np
@@ -69,7 +69,7 @@ class QueryParam:
     ll_keywords: list[str] = field(default_factory=list)
     """List of low-level keywords to refine retrieval focus."""
 
-    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    conversation_history: list[dict[str, str]] = field(default_factory=list)
     """Stores past conversation history to maintain context.
     Format: [{"role": "user/assistant", "content": "message"}].
     """
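For reference, a QueryParam built against the narrowed `list[dict[str, str]]` annotation would look like this (a sketch; the surrounding call sites do not appear in this hunk):

from lightrag.base import QueryParam

param = QueryParam(
    mode="hybrid",
    conversation_history=[
        # Every key and value is a plain string, matching the documented format.
        {"role": "user", "content": "Who wrote the report?"},
        {"role": "assistant", "content": "The 2023 audit team."},
    ],
)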
@@ -83,19 +83,15 @@ class StorageNameSpace:
     namespace: str
     global_config: dict[str, Any]
 
-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
         pass
 
-    async def query_done_callback(self):
-        """Commit the storage operations after querying"""
-        pass
-
 
 @dataclass
 class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
-    meta_fields: set = field(default_factory=set)
+    meta_fields: set[str] = field(default_factory=set)
 
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         raise NotImplementedError
@@ -106,12 +102,20 @@ class BaseVectorStorage(StorageNameSpace):
         """
         raise NotImplementedError
 
+    async def delete_entity(self, entity_name: str) -> None:
+        """Delete a single entity by its name"""
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        """Delete relations for a given entity by scanning metadata"""
+        raise NotImplementedError
+
 
 @dataclass
 class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc
+    embedding_func: EmbeddingFunc | None = None
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
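The base class now declares the entity-deletion hooks that concrete backends (FAISS below, for instance) override. A minimal in-memory subclass illustrating the contract, purely as a sketch (the `src_id`/`tgt_id` metadata keys are assumptions, not part of this hunk):

from dataclasses import dataclass, field
from typing import Any

from lightrag.base import BaseVectorStorage

@dataclass
class InMemoryVectorStorage(BaseVectorStorage):  # illustrative only
    _rows: dict[str, dict[str, Any]] = field(default_factory=dict)

    async def delete_entity(self, entity_name: str) -> None:
        # Drop the single row whose id matches the entity.
        self._rows.pop(entity_name, None)

    async def delete_entity_relation(self, entity_name: str) -> None:
        # Scan metadata and drop every relation row touching the entity.
        self._rows = {
            k: v
            for k, v in self._rows.items()
            if entity_name not in (v.get("src_id"), v.get("tgt_id"))
        }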
@@ -130,50 +134,75 @@ class BaseKVStorage(StorageNameSpace):
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None
+    """Check if a node exists in the graph."""
 
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
 
+    """Check if an edge exists in the graph."""
+
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError
 
+    """Get the degree of a node."""
+
     async def node_degree(self, node_id: str) -> int:
         raise NotImplementedError
 
+    """Get the degree of an edge."""
+
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError
 
-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    """Get a node by its id."""
+
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError
 
+    """Get an edge by its source and target node ids."""
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    ) -> dict[str, str] | None:
         raise NotImplementedError
 
-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    """Get all edges connected to a node."""
+
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    """Upsert a node into the graph."""
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         raise NotImplementedError
 
+    """Upsert an edge into the graph."""
+
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         raise NotImplementedError
 
-    async def delete_node(self, node_id: str):
+    """Delete a node from the graph."""
+
+    async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError
 
-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+    """Embed nodes using an algorithm."""
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
 
+    """Get all labels in the graph."""
+
     async def get_all_labels(self) -> list[str]:
         raise NotImplementedError
 
+    """Get a knowledge graph of a node."""
+
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
@@ -205,9 +234,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
@@ -220,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
         """Get counts of documents in each status"""
         raise NotImplementedError
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        raise NotImplementedError
-
-    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        raise NotImplementedError
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        raise NotImplementedError
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents with a specific status"""
         raise NotImplementedError
 
     async def update_doc_status(self, data: dict[str, Any]) -> None:
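Four status-specific getters collapse into one parameterized accessor, so call sites migrate by passing the DocStatus enum. A small self-contained sketch of the migrated call pattern:

from lightrag.base import DocStatus, DocStatusStorage

async def count_failed(storage: DocStatusStorage) -> int:
    # Old call sites like storage.get_failed_docs() now route through the enum.
    failed = await storage.get_docs_by_status(DocStatus.FAILED)
    return len(failed)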
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
@@ -2,7 +2,7 @@ import asyncio
 from dataclasses import dataclass
 from typing import Union
 import numpy as np
-from chromadb import HttpClient
+from chromadb import HttpClient, PersistentClient
 from chromadb.config import Settings
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
@@ -49,6 +49,16 @@ class ChromaVectorDBStorage(BaseVectorStorage):
                 **user_collection_settings,
             }
 
+            local_path = config.get("local_path", None)
+            if local_path:
+                self._client = PersistentClient(
+                    path=local_path,
+                    settings=Settings(
+                        allow_reset=True,
+                        anonymized_telemetry=False,
+                    ),
+                )
+            else:
                 auth_provider = config.get(
                     "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
                 )
@@ -57,7 +67,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
 
                 if "token_authn" in auth_provider:
                     headers = {
-                        config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
+                        config.get(
+                            "auth_header_name", "X-Chroma-Token"
+                        ): auth_credentials
                     }
                 elif "basic_authn" in auth_provider:
                     auth_credentials = config.get("auth_credentials", "admin:admin")
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
 
         results = self._collection.query(
-            query_embeddings=embedding.tolist(),
+            query_embeddings=embedding.tolist()
+            if not isinstance(embedding, list)
+            else embedding,
             n_results=top_k * 2,  # Request more results to allow for filtering
             include=["metadatas", "distances", "documents"],
         )
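The first Chroma hunk adds an embedded mode: when `local_path` is present in the storage kwargs, a PersistentClient is used instead of the HTTP client. A hedged configuration sketch (the constructor arguments follow the `vector_db_storage_cls_kwargs` convention seen elsewhere in this diff; the exact LightRAG entry point is assumed):

from lightrag import LightRAG  # assumed entry point; not part of this hunk

rag = LightRAG(
    working_dir="./rag_storage",
    vector_storage="ChromaVectorDBStorage",
    vector_db_storage_cls_kwargs={
        "local_path": "./chroma_data",  # triggers PersistentClient (embedded mode)
        "cosine_better_than_threshold": 0.2,
    },
)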
@@ -27,8 +27,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
     def __post_init__(self):
         # Grab config values if available
-        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        cosine_threshold = config.get("cosine_better_than_threshold")
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
             raise ValueError(
                 "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])
 
-    async def delete_entity_relation(self, entity_name: str):
+    async def delete_entity_relation(self, entity_name: str) -> None:
         """
         Delete relations for a given entity by scanning metadata.
         """
@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
             counts[doc["status"]] += 1
         return counts
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """all documents with a specific status"""
         return {
             k: DocProcessingStatus(**v)
             for k, v in self._data.items()
-            if v["status"] == DocStatus.FAILED
-        }
-
-    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PENDING
-        }
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processed documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PROCESSED
-        }
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return {
-            k: DocProcessingStatus(**v)
-            for k, v in self._data.items()
-            if v["status"] == DocStatus.PROCESSING
+            if v["status"] == status
         }
 
     async def index_done_callback(self):

@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
 
     async def drop(self) -> None:
         self._data = {}
+
+    async def delete(self, ids: list[str]) -> None:
+        for doc_id in ids:
+            self._data.pop(doc_id, None)
+        await self.index_done_callback()
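The new JsonKVStorage.delete() pops each id and then awaits index_done_callback(), so the JSON file on disk is already flushed when the call returns. A short usage sketch (the module path is assumed):

from lightrag.kg.json_kv_impl import JsonKVStorage  # module path assumed

async def prune(kv: JsonKVStorage, doc_ids: list[str]) -> None:
    # pop() with a default makes the call idempotent for unknown ids,
    # and the trailing callback persists the change immediately.
    await kv.delete(doc_ids)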
@@ -4,6 +4,7 @@ import numpy as np
 import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
+import asyncio
 
 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError
 
 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
+    BaseVectorStorage,
     DocProcessingStatus,
     DocStatus,
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 
 
 config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
         database = client.get_database(
             os.environ.get(
                 "MONGO_DATABASE",
                 config.get("mongodb", "database", fallback="LightRAG"),
             )
         )
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as KV {self.namespace}")
+        self._collection_name = self.namespace
+
+        self._data = database.get_collection(self._collection_name)
+        logger.debug(f"Use MongoDB as KV {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()
 
     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
+            update_tasks = []
             for mode, items in data.items():
-                for k, v in tqdm_async(items.items(), desc="Upserting"):
+                for k, v in items.items():
                     key = f"{mode}_{k}"
-                    result = self._data.update_one(
-                        {"_id": key}, {"$setOnInsert": v}, upsert=True
-                    )
-                    if result.upserted_id:
-                        logger.debug(f"\nInserted new document with key: {key}")
-                        data[mode][k]["_id"] = key
+                    data[mode][k]["_id"] = f"{mode}_{k}"
+                    update_tasks.append(
+                        self._data.update_one(
+                            {"_id": key}, {"$setOnInsert": v}, upsert=True
+                        )
+                    )
+            await asyncio.gather(*update_tasks)
         else:
-            for k, v in tqdm_async(data.items(), desc="Upserting"):
-                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            update_tasks = []
+            for k, v in data.items():
                 data[k]["_id"] = k
+                update_tasks.append(
+                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+                )
+            await asyncio.gather(*update_tasks)
 
     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             res = {}
-            v = self._data.find_one({"_id": mode + "_" + id})
+            v = await self._data.find_one({"_id": mode + "_" + id})
             if v:
                 res[id] = v
                 logger.debug(f"llm_response_cache find one by:{id}")
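This hunk swaps the blocking MongoClient for Motor's AsyncIOMotorClient, so every driver call becomes awaitable and the per-document upserts can be dispatched concurrently with asyncio.gather instead of looping under tqdm. A standalone sketch of that pattern (database and collection names here are placeholders):

import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def demo() -> None:
    client = AsyncIOMotorClient("mongodb://root:root@localhost:27017/")
    coll = client["LightRAG"]["demo_kv"]

    # Queue the upserts without awaiting each round trip individually...
    tasks = [
        coll.update_one({"_id": k}, {"$set": {"value": v}}, upsert=True)
        for k, v in {"a": 1, "b": 2}.items()
    ]
    # ...then let the driver run them concurrently.
    await asyncio.gather(*tasks)

    print(await coll.find_one({"_id": "a"}))  # find_one must now be awaited

asyncio.run(demo())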
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
-        )
-        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as doc status {self.namespace}")
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self._data = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()
 
     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        update_tasks = []
         for k, v in data.items():
-            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
             data[k]["_id"] = k
+            update_tasks.append(
+                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            )
+        await asyncio.gather(*update_tasks)
 
     async def drop(self) -> None:
         """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
-        result = list(self._data.aggregate(pipeline))
+        cursor = self._data.aggregate(pipeline)
+        result = await cursor.to_list()
         counts = {}
         for doc in result:
             counts[doc["_id"]] = doc["count"]
@@ -141,8 +175,9 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
-        """Get all documents by status"""
-        result = list(self._data.find({"status": status.value}))
+        """Get all documents with a specific status"""
+        cursor = self._data.find({"status": status.value})
+        result = await cursor.to_list()
         return {
             doc["_id"]: DocProcessingStatus(
                 content=doc["content"],
@@ -156,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
             for doc in result
         }
 
-    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        return await self.get_docs_by_status(DocStatus.FAILED)
-
-    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return await self.get_docs_by_status(DocStatus.PENDING)
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSING)
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSED)
-
 
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
@@ -185,26 +204,27 @@ class MongoGraphStorage(BaseGraphStorage):
             global_config=global_config,
             embedding_func=embedding_func,
         )
-        self.client = AsyncIOMotorClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
-        self.db = self.client[
-            os.environ.get(
-                "MONGO_DATABASE",
-                mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        ]
-        self.collection = self.db[
-            os.environ.get(
-                "MONGO_KG_COLLECTION",
-                config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
-            )
-        ]
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+        self._collection_name = self.namespace
+        self.collection = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as KG {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)
 
     #
     # -------------------------------------------------------------------------
@@ -451,7 +471,7 @@ class MongoGraphStorage(BaseGraphStorage):
         self, source_node_id: str
     ) -> Union[List[Tuple[str, str]], None]:
         """
-        Return a list of (target_id, relation) for direct edges from source_node_id.
+        Return a list of (source_id, target_id) for direct edges from source_node_id.
         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
         """
         pipeline = [
@@ -475,7 +495,7 @@ class MongoGraphStorage(BaseGraphStorage):
             return None
 
         edges = result[0].get("edges", [])
-        return [(e["target"], e["relation"]) for e in edges]
+        return [(source_node_id, e["target"]) for e in edges]
 
     #
    # -------------------------------------------------------------------------
@@ -522,7 +542,7 @@ class MongoGraphStorage(BaseGraphStorage):
 
     async def delete_node(self, node_id: str):
         """
-        1) Remove node's doc entirely.
+        1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
         """
         # Remove inbound edges from all other docs
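The final hunk below adds subgraph retrieval built on MongoDB's $graphLookup stage, traversing documents that each hold one node plus an embedded `edges` array. As a primer for reading it, a minimal standalone traversal over the same document shape (a sketch; the URI, database, and `chunk_entity_relation` collection name are assumptions):

import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def neighbors(label: str, max_depth: int) -> list[dict]:
    coll = AsyncIOMotorClient("mongodb://root:root@localhost:27017/")["LightRAG"]["chunk_entity_relation"]
    pipeline = [
        {"$match": {"_id": label}},  # anchor on the starting node
        {
            "$graphLookup": {
                "from": "chunk_entity_relation",  # collection name assumed
                "startWith": "$edges.target",
                "connectFromField": "edges.target",
                "connectToField": "_id",
                "maxDepth": max_depth,
                "depthField": "depth",
                "as": "connected_nodes",
            }
        },
    ]
    return [doc async for doc in coll.aggregate(pipeline)]

print(asyncio.run(neighbors("Alice", 2)))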
@@ -542,3 +562,359 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
Placeholder for demonstration, raises NotImplementedError.
|
Placeholder for demonstration, raises NotImplementedError.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||||
|
|
||||||
|
#
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# QUERY
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all existing node _id in the database
|
||||||
|
Returns:
|
||||||
|
[id1, id2, ...] # Alphabetically sorted id list
|
||||||
|
"""
|
||||||
|
# Use MongoDB's distinct and aggregation to get all unique labels
|
||||||
|
pipeline = [
|
||||||
|
{"$group": {"_id": "$_id"}}, # Group by _id
|
||||||
|
{"$sort": {"_id": 1}}, # Sort alphabetically
|
||||||
|
]
|
||||||
|
|
||||||
|
cursor = self.collection.aggregate(pipeline)
|
||||||
|
labels = []
|
||||||
|
async for doc in cursor:
|
||||||
|
labels.append(doc["_id"])
|
||||||
|
return labels
|
||||||
|
|
||||||
|
async def get_knowledge_graph(
|
||||||
|
self, node_label: str, max_depth: int = 5
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
"""
|
||||||
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_label: Label of the nodes to start from
|
||||||
|
max_depth: Maximum depth of traversal (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KnowledgeGraph object containing nodes and edges of the subgraph
        """
        label = node_label
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        try:
            if label == "*":
                # Get all nodes and edges
                async for node_doc in self.collection.find({}):
                    node_id = str(node_doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_doc.get("_id")],
                                properties={
                                    k: v
                                    for k, v in node_doc.items()
                                    if k not in ["_id", "edges"]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Process edges
                    for edge in node_doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)
            else:
                # Verify if starting node exists
                start_nodes = self.collection.find({"_id": label})
                start_nodes_exist = await start_nodes.to_list(length=1)
                if not start_nodes_exist:
                    logger.warning(f"Starting node with label {label} does not exist!")
                    return result

                # Use $graphLookup for traversal
                pipeline = [
                    {
                        "$match": {"_id": label}
                    },  # Start with nodes having the specified label
                    {
                        "$graphLookup": {
                            "from": self._collection_name,
                            "startWith": "$edges.target",
                            "connectFromField": "edges.target",
                            "connectToField": "_id",
                            "maxDepth": max_depth,
                            "depthField": "depth",
                            "as": "connected_nodes",
                        }
                    },
                ]

                async for doc in self.collection.aggregate(pipeline):
                    # Add the start node
                    node_id = str(doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[doc.get("_id")],
                                properties={
                                    k: v
                                    for k, v in doc.items()
                                    if k
                                    not in [
                                        "_id",
                                        "edges",
                                        "connected_nodes",
                                        "depth",
                                    ]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Add edges from start node
                    for edge in doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)

                    # Add connected nodes and their edges
                    for connected in doc.get("connected_nodes", []):
                        node_id = str(connected["_id"])
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=node_id,
                                    labels=[connected.get("_id")],
                                    properties={
                                        k: v
                                        for k, v in connected.items()
                                        if k not in ["_id", "edges", "depth"]
                                    },
                                )
                            )
                            seen_nodes.add(node_id)

                            # Add edges from connected nodes
                            for edge in connected.get("edges", []):
                                edge_id = f"{node_id}-{edge['target']}"
                                if edge_id not in seen_edges:
                                    result.edges.append(
                                        KnowledgeGraphEdge(
                                            id=edge_id,
                                            type=edge.get("relation", ""),
                                            source=node_id,
                                            target=edge["target"],
                                            properties={
                                                k: v
                                                for k, v in edge.items()
                                                if k not in ["target", "relation"]
                                            },
                                        )
                                    )
                                    seen_edges.add(edge_id)

            logger.info(
                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
            )

        except PyMongoError as e:
            logger.error(f"MongoDB query failed: {str(e)}")

        return result
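For reference, the `$graphLookup` traversal above assumes nodes carry their outgoing edges inline. A sketch of the expected document shape (field names are the ones read by the code; the concrete values are hypothetical):

```python
# Hypothetical node document in the graph collection. `_id` doubles as the
# node label, and $graphLookup walks the inline `edges` array
# (connectFromField="edges.target" -> connectToField="_id").
node_doc = {
    "_id": "ARTIFICIAL INTELLIGENCE",             # node label / primary key
    "entity_type": "CONCEPT",                     # any extra key becomes a node property
    "description": "Field of computer science",
    "edges": [
        {"target": "MACHINE LEARNING", "relation": "INCLUDES", "weight": 1.0},
    ],
}
```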
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]

        logger.debug(f"Use MongoDB as VDB {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)

        # Ensure vector index exists
        self.create_vector_index(uri, database.name, self._collection_name)

    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
        """Creates an Atlas Vector Search index."""
        client = MongoClient(uri)
        collection = client.get_database(database_name).get_collection(
            self._collection_name
        )

        try:
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
                            "path": "vector",
                            "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
                        }
                    ]
                },
                name="vector_knn_index",
                type="vectorSearch",
            )

            collection.create_search_index(search_index_model)
            logger.info("Vector index created successfully.")

        except PyMongoError as _:
            logger.debug("vector index already exists")

    async def upsert(self, data: dict[str, dict]):
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()

        update_tasks = []
        for doc in list_data:
            update_tasks.append(
                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            )
        await asyncio.gather(*update_tasks)

        return list_data

    async def query(self, query, top_k=5):
        """Queries the vector database using Atlas Vector Search."""
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        # Define the aggregation pipeline with the converted query vector
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Ensure this matches the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    "numCandidates": 100,  # Adjust for performance
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline
        cursor = self._data.aggregate(pipeline)
        results = await cursor.to_list()

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]
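A minimal usage sketch for `MongoVectorDBStorage`. The constructor fields mirror how the other storages are wired up elsewhere in this commit (`namespace`, `global_config`, `embedding_func`, `meta_fields`); `my_embedding_func` is a stand-in for a real async embedding callable, and it must expose an `embedding_dim` attribute because `create_vector_index` reads it:

```python
import asyncio

storage = MongoVectorDBStorage(
    namespace="chunks",
    global_config={
        "embedding_batch_num": 32,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    },
    embedding_func=my_embedding_func,  # hypothetical; needs .embedding_dim
    meta_fields=set(),
)

async def demo() -> None:
    # Requires a reachable MongoDB Atlas deployment with vector search enabled.
    await storage.upsert({"doc-1": {"content": "hello world"}})
    hits = await storage.query("greeting", top_k=3)
    print([(h["id"], h["distance"]) for h in hits])

asyncio.run(demo())
```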
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists; if not, create it."""
    client = MongoClient(uri)
    database = client.get_database(database_name)

    collection_names = database.list_collection_names()

    if collection_name not in collection_names:
        database.create_collection(collection_name)
        logger.info(f"Created collection: {collection_name}")
    else:
        logger.debug(f"Collection '{collection_name}' already exists.")
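One caveat in `create_vector_index` above: every `PyMongoError` is treated as "index already exists", which can also swallow real failures (auth errors, non-Atlas deployments). An explicit existence probe is possible where `list_search_indexes` is available (assuming PyMongo 4.6+ against Atlas); this is an alternative sketch, not part of the change:

```python
from pymongo import MongoClient

def vector_index_exists(
    uri: str, database_name: str, collection_name: str, index_name: str = "vector_knn_index"
) -> bool:
    # Probe Atlas Search indexes by name instead of inferring existence
    # from a caught exception.
    collection = MongoClient(uri)[database_name][collection_name]
    return any(ix.get("name") == index_name for ix in collection.list_search_indexes())
```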
@@ -79,8 +79,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
-        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        cosine_threshold = config.get("cosine_better_than_threshold")
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
             raise ValueError(
                 "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")

-    async def delete_entity_relation(self, entity_name: str):
+    async def delete_entity_relation(self, entity_name: str) -> None:
         try:
             relations = [
                 dp
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
     async def index_done_callback(self):
         print("KG successfully indexed.")

+    async def _label_exists(self, label: str) -> bool:
+        """Check if a label exists in the Neo4j database."""
+        query = "CALL db.labels() YIELD label RETURN label"
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                result = await session.run(query)
+                labels = [record["label"] for record in await result.data()]
+                return label in labels
+        except Exception as e:
+            logger.error(f"Error checking label existence: {e}")
+            return False
+
+    async def _ensure_label(self, label: str) -> str:
+        """Ensure a label exists by validating it."""
+        clean_label = label.strip('"')
+        if not await self._label_exists(clean_label):
+            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
+        return clean_label
+
     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = node_id.strip('"')
+        entity_name_label = await self._ensure_label(node_id)
         async with self._driver.session(database=self._DATABASE) as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
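Note that `_ensure_label` issues `CALL db.labels()` on every lookup, so each node or edge operation now pays an extra round trip. If that ever matters, a cached variant is one option; the sketch below is illustrative only and not part of this change (`_labels_cache` is an assumed instance attribute):

```python
async def _label_exists_cached(self, label: str) -> bool:
    # Populate the label set once; callers must reset self._labels_cache
    # to None after writes that can introduce new labels.
    if getattr(self, "_labels_cache", None) is None:
        async with self._driver.session(database=self._DATABASE) as session:
            result = await session.run("CALL db.labels() YIELD label RETURN label")
            self._labels_cache = {record["label"] for record in await result.data()}
    return label in self._labels_cache
```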
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
             return single_result["edgeExists"]

     async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Get node by its label identifier.
+
+        Args:
+            node_id: The node label to look up
+
+        Returns:
+            dict: Node properties if found
+            None: If node not found
+        """
         async with self._driver.session(database=self._DATABASE) as session:
-            entity_name_label = node_id.strip('"')
+            entity_name_label = await self._ensure_label(node_id)
             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
             result = await session.run(query)
             record = await result.single()
@@ -226,18 +253,20 @@ class Neo4JStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
-        """
-        Find all edges between nodes of two given labels
+        """Find edge between two nodes identified by their labels.

         Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
+            source_node_id (str): Label of the source node
+            target_node_id (str): Label of the target node

         Returns:
-            list: List of all relationships/edges found
+            dict: Edge properties if found; otherwise default properties
+            (at least {"weight": 0.0, "source_id": None, "target_id": None})
         """
-        async with self._driver.session(database=self._DATABASE) as session:
-            query = f"""
-            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+        try:
+            entity_name_label_source = source_node_id.strip('"')
+            entity_name_label_target = target_node_id.strip('"')
+
+            async with self._driver.session(database=self._DATABASE) as session:
+                query = f"""
+                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
@@ -250,14 +279,47 @@ class Neo4JStorage(BaseGraphStorage):

-            result = await session.run(query)
-            record = await result.single()
-            if record:
-                result = dict(record["edge_properties"])
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                )
-                return result
-            else:
-                return None
+                result = await session.run(query)
+                record = await result.single()
+                if record and "edge_properties" in record:
+                    try:
+                        result = dict(record["edge_properties"])
+                        # Ensure required keys exist with defaults
+                        required_keys = {
+                            "weight": 0.0,
+                            "source_id": None,
+                            "target_id": None,
+                        }
+                        for key, default_value in required_keys.items():
+                            if key not in result:
+                                result[key] = default_value
+                                logger.warning(
+                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                    f"missing {key}, using default: {default_value}"
+                                )
+
+                        logger.debug(
+                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+                        )
+                        return result
+                    except (KeyError, TypeError, ValueError) as e:
+                        logger.error(
+                            f"Error processing edge properties between {entity_name_label_source} "
+                            f"and {entity_name_label_target}: {str(e)}"
+                        )
+                        # Return default edge properties on error
+                        return {"weight": 0.0, "source_id": None, "target_id": None}
+
+                logger.debug(
+                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                )
+                # Return default edge properties when no edge found
+                return {"weight": 0.0, "source_id": None, "target_id": None}
+
+        except Exception as e:
+            logger.error(
+                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
+            )
+            # Return default edge properties on error
+            return {"weight": 0.0, "source_id": None, "target_id": None}

     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
         node_label = source_node_id.strip('"')
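With the defaults above, `get_edge` now always returns a dict, so callers can read `weight`, `source_id`, and `target_id` without None checks. An illustrative caller (the entity names are made up), run inside an async context:

```python
edge = await graph.get_edge('"DUNE"', '"FRANK HERBERT"')
# Safe even when the relationship is missing or malformed:
score = edge["weight"]  # falls back to 0.0 instead of raising KeyError
```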
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
-        label = node_id.strip('"')
+        label = await self._ensure_label(node_id)
         properties = node_data

         async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
                     neo4jExceptions.ServiceUnavailable,
                     neo4jExceptions.TransientError,
                     neo4jExceptions.WriteServiceUnavailable,
+                    neo4jExceptions.ClientError,
                 )
             ),
         )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
         """
-        source_node_label = source_node_id.strip('"')
-        target_node_label = target_node_id.strip('"')
+        source_label = await self._ensure_label(source_node_id)
+        target_label = await self._ensure_label(target_node_id)
         edge_properties = edge_data

         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
-            MATCH (source:`{source_node_label}`)
+            MATCH (source:`{source_label}`)
             WITH source
-            MATCH (target:`{target_node_label}`)
+            MATCH (target:`{target_label}`)
             MERGE (source)-[r:DIRECTED]->(target)
             SET r += $properties
             RETURN r
             """
-            await tx.run(query, properties=edge_properties)
+            result = await tx.run(query, properties=edge_properties)
+            record = await result.single()
             logger.debug(
-                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
+                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
             )

         try:
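The `ClientError` addition above widens the retry filter. For context, the exception tuple feeds a tenacity-style `retry_if_exception_type`; the decorator shape below is an assumption about the surrounding code, which the hunk itself does not show:

```python
from neo4j import exceptions as neo4jExceptions
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,  # the addition in this hunk
        )
    ),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    ...
```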
@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
     async def get_docs_by_status(
         self, status: DocStatus
     ) -> Dict[str, DocProcessingStatus]:
-        """Get all documents by status"""
+        """Get all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
         params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
             for element in result
         }

-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        return await self.get_docs_by_status(DocStatus.FAILED)
-
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return await self.get_docs_by_status(DocStatus.PENDING)
-
-    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all processing documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSING)
-
-    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all procesed documents"""
-        return await self.get_docs_by_status(DocStatus.PROCESSED)
-
     async def index_done_callback(self):
         """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
         logger.info("Doc status had been saved into postgresql db!")
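With the per-status wrappers removed, call sites select the status explicitly, as the pipeline code later in this commit does:

```python
from lightrag.base import DocStatus  # import path assumed

failed = await doc_status.get_docs_by_status(DocStatus.FAILED)
pending = await doc_status.get_docs_by_status(DocStatus.PENDING)
```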
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
             "FaissVectorDBStorage",
             "QdrantVectorDBStorage",
             "OracleVectorDBStorage",
+            "MongoVectorDBStorage",
         ],
         "required_methods": ["query", "upsert"],
     },
@@ -86,12 +89,12 @@ STORAGE_IMPLEMENTATIONS = {
             "PGDocStatusStorage",
             "MongoDocStatusStorage",
         ],
-        "required_methods": ["get_pending_docs"],
+        "required_methods": ["get_docs_by_status"],
     },
 }

 # Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
         "ORACLE_PASSWORD",
         "ORACLE_CONFIG_DIR",
     ],
+    "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
     "MongoKVStorage": ".kg.mongo_impl",
     "MongoDocStatusStorage": ".kg.mongo_impl",
     "MongoGraphStorage": ".kg.mongo_impl",
+    "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     "TiDBKVStorage": ".kg.tidb_impl",
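`STORAGE_ENV_REQUIREMENTS` maps each backend to the environment variables it cannot start without; an empty list, as for the new `MongoVectorDBStorage`, means no hard requirement. A typical startup check built on it might look like this (the helper name is illustrative):

```python
import os

def missing_storage_env(storage_name: str) -> list[str]:
    # Required-but-unset environment variables for one implementation.
    required = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    return [var for var in required if var not in os.environ]

assert missing_storage_env("MongoVectorDBStorage") == []  # no env vars required
```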
@@ -176,7 +181,7 @@ STORAGES = {
 }


-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None

-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib

         module = importlib.import_module(module_name, package=package)
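For context, `lazy_external_import` returns a factory that defers the module import to the first call, so optional backends only need their dependencies installed when actually selected:

```python
# Nothing is imported yet at this point:
storage_factory = lazy_external_import(".kg.mongo_impl", "MongoVectorDBStorage")

# The import (and any ImportError for a missing driver) happens on the first
# call, and keyword arguments are forwarded to the class constructor unchanged:
# storage = storage_factory(namespace="chunks", global_config=..., embedding_func=...)
```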
@@ -225,7 +230,7 @@ class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

     working_dir: str = field(
-        default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
+        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""

@@ -302,7 +307,7 @@ class LightRAG:
         - random_seed: Seed value for reproducibility.
     """

-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None
     """Function for computing text embeddings. Must be set before use."""

     embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:

     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""

-    # extension
-    addon_params: dict[str, Any] = field(default_factory=dict)
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
     )
@@ -354,7 +357,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
@@ -443,77 +446,74 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
             self._get_storage_class(self.kv_storage)
-        )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
             self.vector_storage
-        )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
             self.graph_storage
-        )
-        self.key_string_value_json_storage_cls = partial(
+        )  # type: ignore
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )

-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )

-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
@@ -527,13 +527,12 @@ class LightRAG:
             embedding_func=None,
         )

-        # What's for, Is this nessisary ?
         if self.llm_response_cache and hasattr(
             self.llm_response_cache, "global_config"
         ):
             hashing_kv = self.llm_response_cache
         else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
@@ -542,7 +541,7 @@ class LightRAG:

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -559,68 +558,45 @@ class LightRAG:
             node_label=nodel_label, max_depth=max_depth
         )

-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class

-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
-
     def insert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Sync Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
-            chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
         )

     async def ainsert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Async Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
-            chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
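The rename from `string_or_strings` to `input` only affects keyword callers; positional calls keep working. New-style call, sketched with a minimal (incomplete) instance:

```python
rag = LightRAG(working_dir="./rag_storage")  # a real setup also needs llm_model_func and embedding_func
rag.insert(input=["first document ...", "second document ..."], split_by_character="\n\n")
```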
@@ -677,7 +653,7 @@ class LightRAG:
         if update_storage:
             await self._insert_done()

-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
         """
         Pipeline for Processing Documents

@@ -686,11 +662,11 @@ class LightRAG:
         3. Filter out already processed documents
         4. Enqueue document in status
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]

         # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))

         # 2. Generate document IDs and initial status
         new_docs: dict[str, Any] = {
@@ -739,11 +715,11 @@ class LightRAG:
         # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}

-        processing_docs = await self.doc_status.get_processing_docs()
+        processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
         to_process_docs.update(processing_docs)
-        failed_docs = await self.doc_status.get_failed_docs()
+        failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
         to_process_docs.update(failed_docs)
-        pendings_docs = await self.doc_status.get_pending_docs()
+        pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
         to_process_docs.update(pendings_docs)

         if not to_process_docs:
@@ -857,8 +833,9 @@ class LightRAG:
             raise e

     async def _insert_done(self):
-        tasks = []
-        for storage_inst in [
+        tasks = [
+            cast(StorageNameSpace, storage_inst).index_done_callback()
+            for storage_inst in [  # type: ignore
             self.full_docs,
             self.text_chunks,
             self.llm_response_cache,
@@ -866,23 +843,22 @@ class LightRAG:
             self.relationships_vdb,
             self.chunks_vdb,
             self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+            ]
+            if storage_inst is not None
+        ]
         await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, Any]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
         update_storage = False
         try:
             # Insert chunks into vector storage
-            all_chunks_data = {}
-            chunk_to_source_map = {}
-            for chunk_data in custom_kg.get("chunks", []):
+            all_chunks_data: dict[str, dict[str, str]] = {}
+            chunk_to_source_map: dict[str, str] = {}
+            for chunk_data in custom_kg.get("chunks", {}):
                 chunk_content = chunk_data["content"]
                 source_id = chunk_data["source_id"]
                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
                 chunk_to_source_map[source_id] = chunk_id
                 update_storage = True

-            if self.chunks_vdb is not None and all_chunks_data:
-                await self.chunks_vdb.upsert(all_chunks_data)
-            if self.text_chunks is not None and all_chunks_data:
-                await self.text_chunks.upsert(all_chunks_data)
+            if all_chunks_data:
+                await self.chunks_vdb.upsert(all_chunks_data)
+            if all_chunks_data:
+                await self.text_chunks.upsert(all_chunks_data)

             # Insert entities into knowledge graph
-            all_entities_data = []
+            all_entities_data: list[dict[str, str]] = []
             for entity_data in custom_kg.get("entities", []):
                 entity_name = f'"{entity_data["entity_name"].upper()}"'
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
                 )

                 # Prepare node data
-                node_data = {
+                node_data: dict[str, str] = {
                     "entity_type": entity_type,
                     "description": description,
                     "source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
                 update_storage = True

             # Insert relationships into knowledge graph
-            all_relationships_data = []
+            all_relationships_data: list[dict[str, str]] = []
             for relationship_data in custom_kg.get("relationships", []):
                 src_id = f'"{relationship_data["src_id"].upper()}"'
                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
                         "source_id": source_id,
                     },
                 )
-                edge_data = {
+                edge_data: dict[str, str] = {
                     "src_id": src_id,
                     "tgt_id": tgt_id,
                     "description": description,
@@ -980,7 +956,6 @@ class LightRAG:
                 update_storage = True

             # Insert entities into vector storage if needed
-            if self.entities_vdb is not None:
             data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                     "content": dp["entity_name"] + dp["description"],
@@ -991,7 +966,6 @@ class LightRAG:
             await self.entities_vdb.upsert(data_for_vdb)

             # Insert relationships into vector storage if needed
-            if self.relationships_vdb is not None:
             data_for_vdb = {
                 compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                     "src_id": dp["src_id"],
@@ -1004,17 +978,46 @@ class LightRAG:
                 for dp in all_relationships_data
             }
             await self.relationships_vdb.upsert(data_for_vdb)

         finally:
             if update_storage:
                 await self._insert_done()

-    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+    def query(
+        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
+    ) -> str | Iterator[str]:
+        """
+        Perform a sync query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.aquery(query, prompt, param))
+
+        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore

     async def aquery(
-        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
-    ):
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str | AsyncIterator[str]:
+        """
+        Perform an async query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query,
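Callers that passed `prompt` positionally in the old `(query, prompt, param)` order need updating, since `param` now comes second and `prompt` is optional:

```python
from lightrag import QueryParam  # import path assumed

answer = rag.query(
    "Who wrote Dune?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # None falls back to PROMPTS["rag_response"]
)
```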
@@ -1094,7 +1097,7 @@ class LightRAG:

     async def aquery_with_separate_keyword_extraction(
         self, query: str, prompt: str, param: QueryParam = QueryParam()
-    ):
+    ) -> str | AsyncIterator[str]:
         """
         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
         2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
             ),
         )

-        param.hl_keywords = (hl_keywords,)
-        param.ll_keywords = (ll_keywords,)
+        param.hl_keywords = hl_keywords
+        param.ll_keywords = ll_keywords

         # ---------------------
         # STEP 2: Final Query Logic
@@ -1146,7 +1149,7 @@ class LightRAG:
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
                 global_config=asdict(self),
-                embedding_func=self.embedding_funcne,
+                embedding_func=self.embedding_func,
             ),
         )
         elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
         return response

     async def _query_done(self):
-        tasks = []
-        for storage_inst in [self.llm_response_cache]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await self.llm_response_cache.index_done_callback()

     def delete_by_entity(self, entity_name: str):
         loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
             logger.error(f"Error while deleting entity '{entity_name}': {e}")

     async def _delete_by_entity_done(self):
-        tasks = []
-        for storage_inst in [
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in [  # type: ignore
+                    self.entities_vdb,
+                    self.relationships_vdb,
+                    self.chunk_entity_relation_graph,
+                ]
+            ]
+        )

     def _get_content_summary(self, content: str, max_length: int = 100) -> str:
         """Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
         """
         return await self.doc_status.get_status_counts()

-    async def adelete_by_doc_id(self, doc_id: str):
+    async def adelete_by_doc_id(self, doc_id: str) -> None:
         """Delete a document and all its related data

         Args:
@@ -1273,6 +1271,9 @@ class LightRAG:

             # 2. Get all related chunks
             chunks = await self.text_chunks.get_by_id(doc_id)
+            if not chunks:
+                return
+
             chunk_ids = list(chunks.keys())
             logger.debug(f"Found {len(chunk_ids)} chunks to delete")

@@ -1443,13 +1444,9 @@ class LightRAG:
         except Exception as e:
             logger.error(f"Error while deleting document {doc_id}: {e}")

-    def delete_by_doc_id(self, doc_id: str):
-        """Synchronous version of adelete"""
-        return asyncio.run(self.adelete_by_doc_id(doc_id))
-
     async def get_entity_info(
         self, entity_name: str, include_vector_data: bool = False
-    ):
+    ) -> dict[str, str | None | dict[str, str]]:
         """Get detailed information of an entity

         Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
|
|||||||
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
||||||
source_id = node_data.get("source_id") if node_data else None
|
source_id = node_data.get("source_id") if node_data else None
|
||||||
|
|
||||||
result = {
|
result: dict[str, str | None | dict[str, str]] = {
|
||||||
"entity_name": entity_name,
|
"entity_name": entity_name,
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
"graph_data": node_data,
|
"graph_data": node_data,
|
@@ -1483,21 +1480,6 @@ class LightRAG:
 
         return result
 
-    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
-        """Synchronous version of getting entity information
-
-        Args:
-            entity_name: Entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
-        finally:
-            tracemalloc.stop()
-
     async def get_relation_info(
         self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
@@ -1525,7 +1507,7 @@ class LightRAG:
         )
         source_id = edge_data.get("source_id") if edge_data else None
 
-        result = {
+        result: dict[str, str | None | dict[str, str]] = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
             result["vector_data"] = vector_data[0] if vector_data else None
 
         return result
-
-    def get_relation_info_sync(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
-    ):
-        """Synchronous version of getting relationship information
-
-        Args:
-            src_entity: Source entity name (no need for quotes)
-            tgt_entity: Target entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(
-                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
-            )
-        finally:
-            tracemalloc.stop()
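With the synchronous wrappers (delete_by_doc_id, get_entity_info_sync, get_relation_info_sync) removed, callers drive the async API themselves. A minimal sketch, assuming an already-initialized LightRAG instance named `rag` (an assumption, not shown in this diff):

```python
import asyncio

# Sketch only: `rag` stands for an initialized LightRAG instance.
async def inspect_and_delete(doc_id: str) -> None:
    info = await rag.get_entity_info("some entity", include_vector_data=False)
    print(info.get("source_id"))
    await rag.adelete_by_doc_id(doc_id)  # replaces the removed delete_by_doc_id

asyncio.run(inspect_and_delete("doc-123"))
```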
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
 
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """
 
-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
 
@@ -66,7 +68,11 @@ class MultiModel:
         return self._models[self._current_model]
 
     async def llm_model_func(
-        self, prompt, system_prompt=None, history_messages=[], **kwargs
+        self,
+        prompt: str,
+        system_prompt: str | None = None,
+        history_messages: list[dict[str, Any]] = [],
+        **kwargs: Any,
     ) -> str:
         kwargs.pop("model", None)  # stop from overwriting the custom model name
         kwargs.pop("keyword_extraction", None)
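The fully annotated signature does not change the behaviour: MultiModel still rotates through its configured models on each call. A standalone sketch of that round-robin idea, where the toy ModelSlot stands in for lightrag's Model (gen_func plus kwargs):

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class ModelSlot:
    gen_func: Callable[..., str]
    kwargs: dict[str, Any]

slots = [
    ModelSlot(gen_func=lambda prompt, **kw: f"A:{prompt}", kwargs={"model": "m-a"}),
    ModelSlot(gen_func=lambda prompt, **kw: f"B:{prompt}", kwargs={"model": "m-b"}),
]
current = 0
for prompt in ["one", "two", "three"]:
    slot = slots[current]
    current = (current + 1) % len(slots)  # rotate across API keys / models
    print(slot.gen_func(prompt, **slot.kwargs))
```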
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable
 
 

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
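The chunking parameters keep their defaults (128-token overlap, 1024-token windows). A standalone sketch of the window/overlap arithmetic, with a whitespace split standing in for the real tiktoken tokens:

```python
# Toy version of the sliding-window slicing behind chunking_by_token_size;
# sizes are shrunk so the output stays readable.
def chunk_tokens(tokens: list[str], max_size: int = 8, overlap: int = 2) -> list[list[str]]:
    step = max_size - overlap
    return [tokens[i : i + max_size] for i in range(0, len(tokens), step)]

words = "the quick brown fox jumps over the lazy dog again and again".split()
for chunk in chunk_tokens(words):
    print(chunk)
```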
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
 
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
-        already_weights.append(already_edge["weight"])
-        already_source_ids.extend(
-            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
-        )
-        already_description.append(already_edge["description"])
-        already_keywords.extend(
-            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
-        )
+        # Handle the case where get_edge returns None or missing fields
+        if already_edge:
+            # Get weight with default 0.0 if missing
+            if "weight" in already_edge:
+                already_weights.append(already_edge["weight"])
+            else:
+                logger.warning(
+                    f"Edge between {src_id} and {tgt_id} missing weight field"
+                )
+                already_weights.append(0.0)
+
+            # Get source_id with empty string default if missing or None
+            if "source_id" in already_edge and already_edge["source_id"] is not None:
+                already_source_ids.extend(
+                    split_string_by_multi_markers(
+                        already_edge["source_id"], [GRAPH_FIELD_SEP]
+                    )
+                )
+
+            # Get description with empty string default if missing or None
+            if (
+                "description" in already_edge
+                and already_edge["description"] is not None
+            ):
+                already_description.append(already_edge["description"])
+
+            # Get keywords with empty string default if missing or None
+            if "keywords" in already_edge and already_edge["keywords"] is not None:
+                already_keywords.extend(
+                    split_string_by_multi_markers(
+                        already_edge["keywords"], [GRAPH_FIELD_SEP]
+                    )
+                )
 
+    # Process edges_data with None checks
     weight = sum([dp["weight"] for dp in edges_data] + already_weights)
     description = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["description"] for dp in edges_data] + already_description))
+        sorted(
+            set(
+                [dp["description"] for dp in edges_data if dp.get("description")]
+                + already_description
+            )
+        )
     )
     keywords = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
+        sorted(
+            set(
+                [dp["keywords"] for dp in edges_data if dp.get("keywords")]
+                + already_keywords
+            )
+        )
     )
     source_id = GRAPH_FIELD_SEP.join(
-        set([dp["source_id"] for dp in edges_data] + already_source_ids)
+        set(
+            [dp["source_id"] for dp in edges_data if dp.get("source_id")]
+            + already_source_ids
+        )
     )
 
     for need_insert_id in [src_id, tgt_id]:
         if not (await knowledge_graph_inst.has_node(need_insert_id)):
             await knowledge_graph_inst.upsert_node(
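The hunk above trades KeyError crashes for logged defaults when a stored edge is incomplete. A standalone sketch of the same defensive reads against a hypothetical edge record:

```python
# Hypothetical edge record that lacks a weight field and has a null source_id.
GRAPH_FIELD_SEP = "<SEP>"
already_edge = {"keywords": "kg<SEP>rag", "source_id": None}

already_weights: list[float] = []
already_source_ids: list[str] = []
already_keywords: list[str] = []

if "weight" in already_edge:
    already_weights.append(already_edge["weight"])
else:
    already_weights.append(0.0)  # the real code also logs a warning here

if "source_id" in already_edge and already_edge["source_id"] is not None:
    already_source_ids.extend(already_edge["source_id"].split(GRAPH_FIELD_SEP))

if "keywords" in already_edge and already_edge["keywords"] is not None:
    already_keywords.extend(already_edge["keywords"].split(GRAPH_FIELD_SEP))

print(already_weights, already_source_ids, already_keywords)
# [0.0] [] ['kg', 'rag']
```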
@@ -295,9 +337,9 @@ async def extract_entities(
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
-) -> Union[BaseGraphStorage, None]:
+    global_config: dict[str, str],
+    llm_response_cache: BaseKVStorage | None = None,
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -563,15 +605,15 @@ async def extract_entities(
 
 
 async def kg_query(
-    query,
+    query: str,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-    prompt: str = "",
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -681,8 +723,8 @@ async def kg_query(
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> tuple[list[str], list[str]]:
     """
     Extract high-level and low-level keywords from the given 'text' using the LLM.
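For callers, extract_keywords_only now advertises exactly what it returns: a (high_level, low_level) pair of string lists. A standalone sketch of that contract, with a JSON blob standing in for the LLM reply:

```python
import json

# Hypothetical LLM reply in the keyword-extraction format.
llm_reply = '{"high_level_keywords": ["retrieval"], "low_level_keywords": ["faiss", "hnsw"]}'
data = json.loads(llm_reply)

hl_keywords: list[str] = data.get("high_level_keywords", [])
ll_keywords: list[str] = data.get("low_level_keywords", [])
print((hl_keywords, ll_keywords))  # tuple[list[str], list[str]]
```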
@@ -778,9 +820,9 @@ async def mix_kg_vector_query(
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
 
@@ -1499,13 +1541,13 @@ def combine_contexts(entities, relationships, sources):
 
 
 async def naive_query(
-    query,
+    query: str,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-):
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1606,9 +1648,9 @@ async def kg_query_with_keywords(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Refactored kg_query that does NOT extract keywords by itself.
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
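The query functions above now admit that they may return either a finished string or a token stream. A standalone sketch of how a caller can branch on that union (the toy coroutine stands in for naive_query and friends):

```python
from __future__ import annotations

import asyncio
from typing import AsyncIterator

async def toy_query(stream: bool) -> str | AsyncIterator[str]:
    if not stream:
        return "full answer"

    async def gen() -> AsyncIterator[str]:
        for token in ["par", "tial ", "answer"]:
            yield token

    return gen()

async def main() -> None:
    resp = await toy_query(stream=True)
    if isinstance(resp, str):
        print(resp)
    else:
        async for token in resp:
            print(token, end="")
        print()

asyncio.run(main())
```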
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"
 
 PROMPTS = {}

@@ -1,26 +1,28 @@
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any, Optional
 
 
 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]
 
 
 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraphEdge(BaseModel):
     id: str
-    type: str
+    type: Optional[str]
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
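A standalone mirror of the updated models (local copies for the sketch, not imports of the real classes), showing that an edge's `type` may now be None while the builtin-generic fields validate as before:

```python
from __future__ import annotations

from typing import Any, Optional
from pydantic import BaseModel

class Node(BaseModel):
    id: str
    labels: list[str]
    properties: dict[str, Any]

class Edge(BaseModel):
    id: str
    type: Optional[str]
    source: str
    target: str
    properties: dict[str, Any]

n1 = Node(id="a", labels=["Person"], properties={})
n2 = Node(id="b", labels=["Org"], properties={})
e = Edge(id="a-b", type=None, source="a", target="b", properties={})
print(e.type)  # None is now allowed
```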
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4
 
@@ -67,12 +69,12 @@ class EmbeddingFunc:
 
 @dataclass
 class ReasoningResponse:
-    reasoning_content: str
+    reasoning_content: str | None
     response_content: str
     tag: str
 
 
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
         raise e from None
 
 
-def compute_args_hash(*args, cache_type: str = None) -> str:
+def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
     """Compute a hash for the given arguments.
 
     Args:
         *args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     return hashlib.md5(args_str.encode()).hexdigest()
 
 
-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+    """
+    Compute a unique ID for a given content string.
+
+    The ID is a combination of the given prefix and the MD5 hash of the content string.
+    """
     return prefix + md5(content.encode()).hexdigest()
 
 
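The new docstring states the contract exactly; restated as a standalone snippet:

```python
from hashlib import md5

# Same body as the hunk above: prefix + MD5 hex digest of the content.
def compute_mdhash_id(content: str, prefix: str = "") -> str:
    return prefix + md5(content.encode()).hexdigest()

print(compute_mdhash_id("hello world", prefix="chunk-"))
# chunk-5eb63bbbe01eeed093cb22bb8f5acdc3
```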
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
|
|||||||
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
||||||
|
|
||||||
|
|
||||||
def is_float_regex(value):
|
def is_float_regex(value: str) -> bool:
|
||||||
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
||||||
|
|
||||||
|
|
||||||
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
def truncate_list_by_token_size(
|
||||||
|
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
|
||||||
|
) -> list[int]:
|
||||||
"""Truncate a list of data by token size"""
|
"""Truncate a list of data by token size"""
|
||||||
if max_token_size <= 0:
|
if max_token_size <= 0:
|
||||||
return []
|
return []
|
||||||
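A standalone sketch of the cumulative-budget truncation this function performs; len() of a whitespace split stands in for the real token encoder, and the sketch returns the truncated items themselves:

```python
from typing import Any, Callable

def truncate_by_budget(
    items: list[Any], key: Callable[[Any], str], max_tokens: int
) -> list[Any]:
    total = 0
    for i, item in enumerate(items):
        total += len(key(item).split())  # toy token count
        if total > max_tokens:
            return items[:i]
    return items

docs = [{"text": "alpha beta"}, {"text": "gamma delta epsilon"}, {"text": "zeta"}]
print(truncate_by_budget(docs, key=lambda d: d["text"], max_tokens=4))
# [{'text': 'alpha beta'}]
```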
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
|
|||||||
return list_data
|
return list_data
|
||||||
|
|
||||||
|
|
||||||
def list_of_list_to_csv(data: List[List[str]]) -> str:
|
def list_of_list_to_csv(data: list[list[str]]) -> str:
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = csv.writer(
|
writer = csv.writer(
|
||||||
output,
|
output,
|
||||||
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
|
|||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def csv_string_to_list(csv_string: str) -> List[List[str]]:
|
def csv_string_to_list(csv_string: str) -> list[list[str]]:
|
||||||
# Clean the string by removing NUL characters
|
# Clean the string by removing NUL characters
|
||||||
cleaned_string = csv_string.replace("\0", "")
|
cleaned_string = csv_string.replace("\0", "")
|
||||||
|
|
||||||
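A standalone round trip in the spirit of list_of_list_to_csv / csv_string_to_list, built directly on the csv module (quoting options are an assumption for the sketch):

```python
import csv
import io

rows = [["entity", "description"], ["Turing", 'mathematician, "founder" of CS']]

buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_ALL).writerows(rows)
csv_text = buf.getvalue()

# Strip NUL characters before parsing, as csv_string_to_list does.
back = [row for row in csv.reader(io.StringIO(csv_text.replace("\0", "")))]
print(back == rows)  # True
```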
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def process_combine_contexts(hl, ll):
|
def process_combine_contexts(hl: str, ll: str):
|
||||||
header = None
|
header = None
|
||||||
list_hl = csv_string_to_list(hl.strip())
|
list_hl = csv_string_to_list(hl.strip())
|
||||||
list_ll = csv_string_to_list(ll.strip())
|
list_ll = csv_string_to_list(ll.strip())
|
||||||
@@ -375,7 +384,7 @@ async def get_best_cached_response(
|
|||||||
llm_func=None,
|
llm_func=None,
|
||||||
original_prompt=None,
|
original_prompt=None,
|
||||||
cache_type=None,
|
cache_type=None,
|
||||||
) -> Union[str, None]:
|
) -> str | None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
|
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
|
||||||
)
|
)
|
||||||
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
|
|||||||
return dot_product / (norm1 * norm2)
|
return dot_product / (norm1 * norm2)
|
||||||
|
|
||||||
|
|
||||||
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
|
def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
|
||||||
"""Quantize embedding to specified bits"""
|
"""Quantize embedding to specified bits"""
|
||||||
# Convert list to numpy array if needed
|
# Convert list to numpy array if needed
|
||||||
if isinstance(embedding, list):
|
if isinstance(embedding, list):
|
||||||
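A standalone sketch of 8-bit min-max quantization, the scheme quantize_embedding's docstring describes (the exact scaling in the real function may differ in detail):

```python
from __future__ import annotations

import numpy as np

def quantize(embedding: np.ndarray | list[float], bits: int = 8):
    emb = np.asarray(embedding, dtype=np.float32)
    min_val, max_val = float(emb.min()), float(emb.max())
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((emb - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val  # min/max are kept for dequantization

q, lo, hi = quantize([0.1, -0.4, 0.9])
print(q, lo, hi)  # [ 98   0 255] -0.4 0.9
```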
@@ -570,9 +579,9 @@ class CacheData:
|
|||||||
args_hash: str
|
args_hash: str
|
||||||
content: str
|
content: str
|
||||||
prompt: str
|
prompt: str
|
||||||
quantized: Optional[np.ndarray] = None
|
quantized: np.ndarray | None = None
|
||||||
min_val: Optional[float] = None
|
min_val: float | None = None
|
||||||
max_val: Optional[float] = None
|
max_val: float | None = None
|
||||||
mode: str = "default"
|
mode: str = "default"
|
||||||
cache_type: str = "query"
|
cache_type: str = "query"
|
||||||
|
|
||||||
@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
|
def get_conversation_turns(
|
||||||
|
conversation_history: list[dict[str, Any]], num_turns: int
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process conversation history to get the specified number of complete turns.
|
Process conversation history to get the specified number of complete turns.
|
||||||
|
|
||||||
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
|
|||||||
Formatted string of the conversation history
|
Formatted string of the conversation history
|
||||||
"""
|
"""
|
||||||
# Group messages into turns
|
# Group messages into turns
|
||||||
turns = []
|
turns: list[list[dict[str, Any]]] = []
|
||||||
messages = []
|
messages: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# First, filter out keyword extraction messages
|
# First, filter out keyword extraction messages
|
||||||
for msg in conversation_history:
|
for msg in conversation_history:
|
||||||
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
|
|||||||
turns = turns[-num_turns:]
|
turns = turns[-num_turns:]
|
||||||
|
|
||||||
# Format the turns into a string
|
# Format the turns into a string
|
||||||
formatted_turns = []
|
formatted_turns: list[str] = []
|
||||||
for turn in turns:
|
for turn in turns:
|
||||||
formatted_turns.extend(
|
formatted_turns.extend(
|
||||||
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
|
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
|
||||||
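A standalone sketch of the turn pairing this function performs, simplified to assume an already-clean alternating history: keep the last N (user, assistant) pairs and render them as labelled lines:

```python
history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "what is RAG?"},
    {"role": "assistant", "content": "retrieval-augmented generation"},
]

# Pair consecutive messages into turns, then keep the last num_turns pairs.
turns = [history[i : i + 2] for i in range(0, len(history) - 1, 2)]
for turn in turns[-1:]:  # num_turns = 1
    print(f"user: {turn[0]['content']}")
    print(f"assistant: {turn[1]['content']}")
```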