Merge branch 'main' into add-env-settings
@@ -1 +1,63 @@
-.env
+# Python-related files and directories
+__pycache__
+.cache
+
+# Virtual environment directories
+*.venv
+
+# Env
+env/
+*.env*
+.env_example
+
+# Distribution / build files
+site
+dist/
+build/
+.eggs/
+*.egg-info/
+*.tgz
+*.tar.gz
+
+# Exclude files and folders
+*.yml
+.dockerignore
+Dockerfile
+Makefile
+
+# Exclude other projects
+/tests
+/scripts
+
+# Python version manager file
+.python-version
+
+# Reports
+*.coverage/
+*.log
+log/
+*.logfire
+
+# Cache
+.cache/
+.mypy_cache
+.pytest_cache
+.ruff_cache
+.gradio
+.logfire
+temp/
+
+# MacOS-related files
+.DS_Store
+
+# VS Code settings (local configuration files)
+.vscode
+
+# file
+TODO.md
+
+# Exclude Git-related files
+.git
+.github
+.gitignore
+.pre-commit-config.yaml
.gitignore
@@ -35,23 +35,27 @@ temp/
 # IDE / Editor Files
 .idea/
-dist/
-env/
+.vscode/
+.vscode/settings.json

+# Framework-specific files
 local_neo4jWorkDir/
 neo4jWorkDir/
-ignore_this.txt
-.venv/
-*.ignore.*
-.ruff_cache/
-gui/
-*.log
-.vscode
-inputs
-rag_storage
-.env
-venv/
+# Data & Storage
+inputs/
+rag_storage/
 examples/input/
 examples/output/

+# Miscellaneous
 .DS_Store
-#Remove config.ini from repo
-*.ini
+TODO.md
+ignore_this.txt
+*.ignore.*

+# Project-specific files
+dickens/
+book.txt
+lightrag-dev/
+gui/
@@ -237,7 +237,7 @@ rag = LightRAG(
 * If you want to use Hugging Face models, you only need to set LightRAG as follows:
 ```python
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm import hf_model_complete, hf_embed
 from transformers import AutoModel, AutoTokenizer
 from lightrag.utils import EmbeddingFunc
@@ -250,7 +250,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
             texts,
             tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
             embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
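
For reference, the renamed helper assembled into one runnable snippet. This is a sketch only; it mirrors the two hunks above, with a placeholder working directory:

```python
from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer

rag = LightRAG(
    working_dir="./rag_storage",  # placeholder working directory
    llm_model_func=hf_model_complete,
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=5000,
        # hf_embed is the post-rename name; hf_embedding was the old one.
        func=lambda texts: hf_embed(
            texts,
            tokenizer=AutoTokenizer.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
            embed_model=AutoModel.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
        ),
    ),
)
```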
@@ -428,9 +428,9 @@ And using a routine to process news documents.

 ```python
 rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
 # Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
 ```

 ### Separate Keyword Extraction
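
Regarding the pipeline rename above: a sketch of how the renamed `input` argument might be fed from a polling routine. The loop, the `fetch_news` callable, and the sleep interval are illustrative assumptions, not part of the diff:

```python
import asyncio

async def ingest_loop(rag, fetch_news):
    # fetch_news is a hypothetical async callable returning a list of strings.
    while True:
        docs = await fetch_news()
        if docs:
            await rag.apipeline_enqueue_documents(docs)
            await rag.apipeline_process_enqueue_documents(docs)
        await asyncio.sleep(60)  # poll interval; arbitrary choice
```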
@@ -113,7 +113,24 @@ async def main():
     )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
-    rag.set_storage_client(db_client=oracle_db)
+    for storage in [
+        rag.vector_db_storage_cls,
+        rag.graph_storage_cls,
+        rag.doc_status,
+        rag.full_docs,
+        rag.text_chunks,
+        rag.llm_response_cache,
+        rag.key_string_value_json_storage_cls,
+        rag.chunks_vdb,
+        rag.relationships_vdb,
+        rag.entities_vdb,
+        rag.graph_storage_cls,
+        rag.chunk_entity_relation_graph,
+        rag.llm_response_cache,
+    ]:
+        # set client
+        storage.db = oracle_db

     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
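
The loop added above amounts to injecting one shared client into every storage object. A condensed sketch of the same idea (attribute names are taken verbatim from the diff; the helper name is hypothetical, and the list is deduplicated relative to the diff, which repeats two entries):

```python
def attach_db_client(rag, db_client):
    """Point every LightRAG storage at one shared connection pool."""
    storages = [
        rag.vector_db_storage_cls,
        rag.graph_storage_cls,
        rag.doc_status,
        rag.full_docs,
        rag.text_chunks,
        rag.llm_response_cache,
        rag.key_string_value_json_storage_cls,
        rag.chunks_vdb,
        rag.relationships_vdb,
        rag.entities_vdb,
        rag.chunk_entity_relation_graph,
    ]
    for storage in storages:
        storage.db = db_client
```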
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

 # ChromaDB Configuration
+CHROMADB_USE_LOCAL_PERSISTENT = False
+# Local PersistentClient Configuration
+CHROMADB_LOCAL_PATH = os.environ.get(
+    "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
+)
+# Remote HttpClient Configuration
 CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
 CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
 CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +66,50 @@ async def create_embedding_function_instance():

 async def initialize_rag():
     embedding_func_instance = await create_embedding_function_instance()
-    return LightRAG(
-        working_dir=WORKING_DIR,
-        llm_model_func=gpt_4o_mini_complete,
-        embedding_func=embedding_func_instance,
-        vector_storage="ChromaVectorDBStorage",
-        log_level="DEBUG",
-        embedding_batch_num=32,
-        vector_db_storage_cls_kwargs={
-            "host": CHROMADB_HOST,
-            "port": CHROMADB_PORT,
-            "auth_token": CHROMADB_AUTH_TOKEN,
-            "auth_provider": CHROMADB_AUTH_PROVIDER,
-            "auth_header_name": CHROMADB_AUTH_HEADER,
-            "collection_settings": {
-                "hnsw:space": "cosine",
-                "hnsw:construction_ef": 128,
-                "hnsw:search_ef": 128,
-                "hnsw:M": 16,
-                "hnsw:batch_size": 100,
-                "hnsw:sync_threshold": 1000,
-            },
-        },
-    )
+    if CHROMADB_USE_LOCAL_PERSISTENT:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "local_path": CHROMADB_LOCAL_PATH,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
+            },
+        )
+    else:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "host": CHROMADB_HOST,
+                "port": CHROMADB_PORT,
+                "auth_token": CHROMADB_AUTH_TOKEN,
+                "auth_provider": CHROMADB_AUTH_PROVIDER,
+                "auth_header_name": CHROMADB_AUTH_HEADER,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
+            },
+        )


 # Run the initialization
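
A short usage sketch for the branching initializer above, assuming an asyncio entry point (`ainsert` is LightRAG's async insert):

```python
import asyncio

async def main():
    rag = await initialize_rag()
    # The local PersistentClient branch is taken when
    # CHROMADB_USE_LOCAL_PERSISTENT is True; otherwise the remote
    # HttpClient configuration applies.
    await rag.ainsert("ChromaDB can now run fully local via PersistentClient.")

if __name__ == "__main__":
    asyncio.run(main())
```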
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
-- Twitter: @ParisNeo_AI
-- Reddit: r/lollms
-- Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
-- Python 3.8+
-- OpenWebui
-- Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, description="Unknown State", status="in_progress", done=False):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "status": status,
-                        "description": description,
-                        "done": done,
-                    },
-                }
-            )
-
-
-class MessageEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, content="Some message"):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "message",
-                    "data": {
-                        "content": content,
-                    },
-                }
-            )
-
-
-class Tools:
-    class Valves(BaseModel):
-        LIGHTRAG_SERVER_URL: str = Field(
-            default="http://localhost:9621/query",
-            description="The base URL for the LightRag server",
-        )
-        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
-            default="hybrid",
-            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
-        )
-        ONLY_NEED_CONTEXT: bool = Field(
-            default=False,
-            description="If True, only the context is needed from the LightRag response",
-        )
-        DEBUG_MODE: bool = Field(
-            default=False,
-            description="If True, debugging information will be emitted",
-        )
-        KEY: str = Field(
-            default="",
-            description="Optional Bearer Key for authentication",
-        )
-        MAX_ENTITIES: int = Field(
-            default=5,
-            description="Maximum number of entities to keep",
-        )
-        MAX_RELATIONSHIPS: int = Field(
-            default=5,
-            description="Maximum number of relationships to keep",
-        )
-        MAX_SOURCES: int = Field(
-            default=3,
-            description="Maximum number of sources to keep",
-        )
-
-    def __init__(self):
-        self.valves = self.Valves()
-        self.headers = {
-            "Content-Type": "application/json",
-            "User-Agent": "LightRag-Tool/1.0",
-        }
-
-    async def query_lightrag(
-        self,
-        query: str,
-        __event_emitter__: Callable[[dict], Any] = None,
-    ) -> str:
-        """
-        Query the LightRag server and retrieve information.
-        This function must be called before answering the user question
-        :params query: The query string to send to the LightRag server.
-        :return: The response from the LightRag server in Markdown format or raw response.
-        """
-        self.status_emitter = StatusEventEmitter(__event_emitter__)
-        self.message_emitter = MessageEventEmitter(__event_emitter__)
-
-        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
-        payload = {
-            "query": query,
-            "mode": str(self.valves.MODE),
-            "stream": False,
-            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
-        }
-        await self.status_emitter.emit("Initializing Lightrag query..")
-
-        if self.valves.DEBUG_MODE:
-            await self.message_emitter.emit(
-                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
-            )
-            await self.message_emitter.emit(
-                "#### Payload Sent to LightRag Server\n```json\n"
-                + json.dumps(payload, indent=4)
-                + "\n```\n"
-            )
-
-        # Add Bearer Key to headers if provided
-        if self.valves.KEY:
-            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
-        try:
-            await self.status_emitter.emit("Sending request to LightRag server")
-
-            response = requests.post(
-                lightrag_url, json=payload, headers=self.headers, timeout=120
-            )
-            response.raise_for_status()
-            data = response.json()
-            await self.status_emitter.emit(
-                status="complete",
-                description="LightRag query Succeeded",
-                done=True,
-            )
-
-            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
-            if self.valves.ONLY_NEED_CONTEXT:
-                try:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### LightRag Server Response\n```json\n"
-                            + data["response"]
-                            + "\n```\n"
-                        )
-                except Exception as ex:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### Exception\n" + str(ex) + "\n"
-                        )
-                    return f"Exception: {ex}"
-                return data["response"]
-            else:
-                if self.valves.DEBUG_MODE:
-                    await self.message_emitter.emit(
-                        "#### LightRag Server Response\n```json\n"
-                        + data["response"]
-                        + "\n```\n"
-                    )
-                await self.status_emitter.emit("Lightrag query success")
-                return data["response"]
-
-        except requests.exceptions.RequestException as e:
-            await self.status_emitter.emit(
-                status="error",
-                description=f"Error during LightRag query: {str(e)}",
-                done=True,
-            )
-            return json.dumps({"error": str(e)})
-
-    def extract_code_blocks(
-        self, text: str, return_remaining_text: bool = False
-    ) -> Union[List[dict], Tuple[List[dict], str]]:
-        """
-        This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
-        Parameters:
-        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
-        return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
-        Returns:
-        Union[List[dict], Tuple[List[dict], str]]:
-        - If return_remaining_text is False: Returns only the list of code block dictionaries
-        - If return_remaining_text is True: Returns a tuple containing:
-          * List of code block dictionaries
-          * String containing the text with all code blocks removed
-
-        Each code block dictionary contains:
-        - 'index' (int): The index of the code block in the text
-        - 'file_name' (str): The name of the file extracted from the preceding line, if available
-        - 'content' (str): The content of the code block
-        - 'type' (str): The type of the code block
-        - 'is_complete' (bool): True if the block has a closing tag, False otherwise
-        """
-        remaining = text
-        bloc_index = 0
-        first_index = 0
-        indices = []
-        text_without_blocks = text
-
-        # Find all code block delimiters
-        while len(remaining) > 0:
-            try:
-                index = remaining.index("```")
-                indices.append(index + first_index)
-                remaining = remaining[index + 3 :]
-                first_index += index + 3
-                bloc_index += 1
-            except Exception:
-                if bloc_index % 2 == 1:
-                    index = len(remaining)
-                    indices.append(index)
-                remaining = ""
-
-        code_blocks = []
-        is_start = True
-
-        # Process code blocks and build text without blocks if requested
-        if return_remaining_text:
-            text_parts = []
-            last_end = 0
-
-        for index, code_delimiter_position in enumerate(indices):
-            if is_start:
-                block_infos = {
-                    "index": len(code_blocks),
-                    "file_name": "",
-                    "section": "",
-                    "content": "",
-                    "type": "",
-                    "is_complete": False,
-                }
-
-                # Store text before code block if returning remaining text
-                if return_remaining_text:
-                    text_parts.append(text[last_end:code_delimiter_position].strip())
-
-                # Check the preceding line for file name
-                preceding_text = text[:code_delimiter_position].strip().splitlines()
-                if preceding_text:
-                    last_line = preceding_text[-1].strip()
-                    if last_line.startswith("<file_name>") and last_line.endswith(
-                        "</file_name>"
-                    ):
-                        file_name = last_line[
-                            len("<file_name>") : -len("</file_name>")
-                        ].strip()
-                        block_infos["file_name"] = file_name
-                    elif last_line.startswith("## filename:"):
-                        file_name = last_line[len("## filename:") :].strip()
-                        block_infos["file_name"] = file_name
-                    if last_line.startswith("<section>") and last_line.endswith(
-                        "</section>"
-                    ):
-                        section = last_line[
-                            len("<section>") : -len("</section>")
-                        ].strip()
-                        block_infos["section"] = section
-
-                sub_text = text[code_delimiter_position + 3 :]
-                if len(sub_text) > 0:
-                    try:
-                        find_space = sub_text.index(" ")
-                    except Exception:
-                        find_space = int(1e10)
-                    try:
-                        find_return = sub_text.index("\n")
-                    except Exception:
-                        find_return = int(1e10)
-                    next_index = min(find_return, find_space)
-                    if "{" in sub_text[:next_index]:
-                        next_index = 0
-                    start_pos = next_index
-
-                    if code_delimiter_position + 3 < len(text) and text[
-                        code_delimiter_position + 3
-                    ] in ["\n", " ", "\t"]:
-                        block_infos["type"] = "language-specific"
-                    else:
-                        block_infos["type"] = sub_text[:next_index]
-
-                    if index + 1 < len(indices):
-                        next_pos = indices[index + 1] - code_delimiter_position
-                        if (
-                            next_pos - 3 < len(sub_text)
-                            and sub_text[next_pos - 3] == "`"
-                        ):
-                            block_infos["content"] = sub_text[
-                                start_pos : next_pos - 3
-                            ].strip()
-                            block_infos["is_complete"] = True
-                        else:
-                            block_infos["content"] = sub_text[
-                                start_pos:next_pos
-                            ].strip()
-                            block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = indices[index + 1] + 3
-                    else:
-                        block_infos["content"] = sub_text[start_pos:].strip()
-                        block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = len(text)
-
-                code_blocks.append(block_infos)
-                is_start = False
-            else:
-                is_start = True
-
-        if return_remaining_text:
-            # Add any remaining text after the last code block
-            if last_end < len(text):
-                text_parts.append(text[last_end:].strip())
-            # Join all non-code parts with newlines
-            text_without_blocks = "\n".join(filter(None, text_parts))
-            return code_blocks, text_without_blocks
-
-        return code_blocks
-
-    def clean(self, csv_content: str):
-        lines = csv_content.splitlines()
-        if lines:
-            # Remove spaces around headers and ensure no spaces between commas
-            header = ",".join([col.strip() for col in lines[0].split(",")])
-            lines[0] = header  # Replace the first line with the cleaned header
-            csv_content = "\n".join(lines)
-        return csv_content
@@ -185,7 +185,8 @@ TiDBVectorDBStorage   TiDB
 PGVectorStorage       Postgres
 FaissVectorDBStorage  Faiss
 QdrantVectorDBStorage Qdrant
-OracleVectorDBStorag  Oracle
+OracleVectorDBStorage Oracle
+MongoVectorDBStorage  MongoDB
 ```

 * DOC_STATUS_STORAGE: supported implement-name
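
With the new row, MongoDB is selectable like any other vector backend. A hedged sketch (the working directory and threshold value are illustrative; the Mongo implementation later in this commit raises if the threshold is missing):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",            # placeholder
    vector_storage="MongoVectorDBStorage",  # new option from the table above
    vector_db_storage_cls_kwargs={
        # MongoVectorDBStorage.__post_init__ requires this key
        # (see the implementation added later in this commit).
        "cosine_better_than_threshold": 0.2,
    },
)
```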
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from dotenv import load_dotenv
 from dataclasses import dataclass, field
@@ -5,10 +7,8 @@ from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -72,7 +72,7 @@ class QueryParam:
     ll_keywords: list[str] = field(default_factory=list)
     """List of low-level keywords to refine retrieval focus."""

-    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    conversation_history: list[dict[str, str]] = field(default_factory=list)
     """Stores past conversation history to maintain context.
     Format: [{"role": "user/assistant", "content": "message"}].
     """
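
The narrowed annotation matches the documented format. A sketch of a conforming history, assuming `QueryParam` is importable from `lightrag.base`:

```python
from lightrag.base import QueryParam

param = QueryParam(
    mode="hybrid",
    conversation_history=[
        # Role/content pairs are plain strings, per the dict[str, str] hint.
        {"role": "user", "content": "Who wrote A Christmas Carol?"},
        {"role": "assistant", "content": "Charles Dickens."},
    ],
)
```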
@@ -86,19 +86,15 @@ class StorageNameSpace:
     namespace: str
     global_config: dict[str, Any]

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
         pass

-    async def query_done_callback(self):
-        """Commit the storage operations after querying"""
-        pass
-

 @dataclass
 class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
-    meta_fields: set = field(default_factory=set)
+    meta_fields: set[str] = field(default_factory=set)

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         raise NotImplementedError
@@ -109,12 +105,20 @@ class BaseVectorStorage(StorageNameSpace):
         """
         raise NotImplementedError

+    async def delete_entity(self, entity_name: str) -> None:
+        """Delete a single entity by its name"""
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        """Delete relations for a given entity by scanning metadata"""
+        raise NotImplementedError
+

 @dataclass
 class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc
+    embedding_func: EmbeddingFunc | None = None

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -133,50 +137,75 @@ class BaseKVStorage(StorageNameSpace):

 @dataclass
 class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None

+    """Check if a node exists in the graph."""
+
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError

+    """Check if an edge exists in the graph."""
+
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError

+    """Get the degree of a node."""
+
     async def node_degree(self, node_id: str) -> int:
         raise NotImplementedError

+    """Get the degree of an edge."""
+
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    """Get a node by its id."""
+
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

+    """Get an edge by its source and target node ids."""
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    ) -> dict[str, str] | None:
         raise NotImplementedError

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    """Get all edges connected to a node."""
+
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         raise NotImplementedError

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    """Upsert a node into the graph."""
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         raise NotImplementedError

+    """Upsert an edge into the graph."""
+
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         raise NotImplementedError

-    async def delete_node(self, node_id: str):
+    """Delete a node from the graph."""
+
+    async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+    """Embed nodes using an algorithm."""
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")

+    """Get all labels in the graph."""
+
     async def get_all_labels(self) -> list[str]:
         raise NotImplementedError

+    """Get a knowledge graph of a node."""
+
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
@@ -208,9 +237,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal
@@ -2,7 +2,7 @@ import asyncio
 from dataclasses import dataclass
 from typing import Union
 import numpy as np
-from chromadb import HttpClient
+from chromadb import HttpClient, PersistentClient
 from chromadb.config import Settings
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
@@ -49,31 +49,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             **user_collection_settings,
         }

-        auth_provider = config.get(
-            "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
-        )
-        auth_credentials = config.get("auth_token", "secret-token")
-        headers = {}
-
-        if "token_authn" in auth_provider:
-            headers = {
-                config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
-            }
-        elif "basic_authn" in auth_provider:
-            auth_credentials = config.get("auth_credentials", "admin:admin")
-
-        self._client = HttpClient(
-            host=config.get("host", "localhost"),
-            port=config.get("port", 8000),
-            headers=headers,
-            settings=Settings(
-                chroma_api_impl="rest",
-                chroma_client_auth_provider=auth_provider,
-                chroma_client_auth_credentials=auth_credentials,
-                allow_reset=True,
-                anonymized_telemetry=False,
-            ),
-        )
+        local_path = config.get("local_path", None)
+        if local_path:
+            self._client = PersistentClient(
+                path=local_path,
+                settings=Settings(
+                    allow_reset=True,
+                    anonymized_telemetry=False,
+                ),
+            )
+        else:
+            auth_provider = config.get(
+                "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
+            )
+            auth_credentials = config.get("auth_token", "secret-token")
+            headers = {}
+
+            if "token_authn" in auth_provider:
+                headers = {
+                    config.get(
+                        "auth_header_name", "X-Chroma-Token"
+                    ): auth_credentials
+                }
+            elif "basic_authn" in auth_provider:
+                auth_credentials = config.get("auth_credentials", "admin:admin")
+
+            self._client = HttpClient(
+                host=config.get("host", "localhost"),
+                port=config.get("port", 8000),
+                headers=headers,
+                settings=Settings(
+                    chroma_api_impl="rest",
+                    chroma_client_auth_provider=auth_provider,
+                    chroma_client_auth_credentials=auth_credentials,
+                    allow_reset=True,
+                    anonymized_telemetry=False,
+                ),
+            )

         self._collection = self._client.get_or_create_collection(
             name=self.namespace,
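
A sketch of how `local_path` would reach the storage class through LightRAG's kwargs, mirroring the ChromaDB demo earlier in this commit (paths are placeholders):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",  # placeholder
    vector_storage="ChromaVectorDBStorage",
    vector_db_storage_cls_kwargs={
        # Presence of local_path selects PersistentClient; omitting it
        # falls back to the remote HttpClient branch above.
        "local_path": "./chroma_data",
        "collection_settings": {"hnsw:space": "cosine"},
    },
)
```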
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])

         results = self._collection.query(
-            query_embeddings=embedding.tolist(),
+            query_embeddings=embedding.tolist()
+            if not isinstance(embedding, list)
+            else embedding,
             n_results=top_k * 2,  # Request more results to allow for filtering
             include=["metadatas", "distances", "documents"],
         )
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])

-    async def delete_entity_relation(self, entity_name: str):
+    async def delete_entity_relation(self, entity_name: str) -> None:
         """
         Delete relations for a given entity by scanning metadata.
         """
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):

     async def drop(self) -> None:
         self._data = {}
+
+    async def delete(self, ids: list[str]) -> None:
+        for doc_id in ids:
+            self._data.pop(doc_id, None)
+        await self.index_done_callback()
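
The new `delete` leans on `dict.pop` with a default, so missing ids are skipped silently and a single persistence callback runs at the end. A tiny illustration of the pop semantics:

```python
# dict.pop(key, None) makes deletion idempotent: missing ids are ignored.
data = {"doc-1": {"content": "a"}, "doc-2": {"content": "b"}}
for doc_id in ["doc-1", "doc-missing"]:
    data.pop(doc_id, None)
assert data == {"doc-2": {"content": "b"}}
```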
@@ -4,6 +4,7 @@ import numpy as np
 import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
+import asyncio

 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError

 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
+    BaseVectorStorage,
     DocProcessingStatus,
     DocStatus,
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge


 config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
         database = client.get_database(
             os.environ.get(
                 "MONGO_DATABASE",
                 config.get("mongodb", "database", fallback="LightRAG"),
             )
         )
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as KV {self.namespace}")
+        self._collection_name = self.namespace
+
+        self._data = database.get_collection(self._collection_name)
+        logger.debug(f"Use MongoDB as KV {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()

     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
+            update_tasks = []
             for mode, items in data.items():
-                for k, v in tqdm_async(items.items(), desc="Upserting"):
+                for k, v in items.items():
                     key = f"{mode}_{k}"
-                    result = self._data.update_one(
-                        {"_id": key}, {"$setOnInsert": v}, upsert=True
-                    )
-                    if result.upserted_id:
-                        logger.debug(f"\nInserted new document with key: {key}")
-                    data[mode][k]["_id"] = key
+                    data[mode][k]["_id"] = f"{mode}_{k}"
+                    update_tasks.append(
+                        self._data.update_one(
+                            {"_id": key}, {"$setOnInsert": v}, upsert=True
+                        )
+                    )
+            await asyncio.gather(*update_tasks)
         else:
-            for k, v in tqdm_async(data.items(), desc="Upserting"):
-                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            update_tasks = []
+            for k, v in data.items():
                 data[k]["_id"] = k
+                update_tasks.append(
+                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+                )
+            await asyncio.gather(*update_tasks)

     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             res = {}
-            v = self._data.find_one({"_id": mode + "_" + id})
+            v = await self._data.find_one({"_id": mode + "_" + id})
             if v:
                 res[id] = v
                 logger.debug(f"llm_response_cache find one by:{id}")
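
The rewritten upserts collect Motor futures and await them together instead of blocking on each round-trip as the old synchronous loop did. A condensed sketch of the pattern (collection and data are placeholders):

```python
import asyncio

async def upsert_all(collection, data: dict[str, dict]) -> None:
    # Motor's update_one returns an awaitable; gathering the futures
    # lets the round-trips overlap instead of running serially.
    tasks = [
        collection.update_one({"_id": k}, {"$set": v}, upsert=True)
        for k, v in data.items()
    ]
    await asyncio.gather(*tasks)
```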
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
-        )
-        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
-        self._data = database.get_collection(self.namespace)
-        logger.info(f"Use MongoDB as doc status {self.namespace}")
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self._data = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.find_one({"_id": id})
+        return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return list(self._data.find({"_id": {"$in": ids}}))
+        cursor = self._data.find({"_id": {"$in": ids}})
+        return await cursor.to_list()

     async def filter_keys(self, data: set[str]) -> set[str]:
-        existing_ids = [
-            str(x["_id"])
-            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
-        ]
-        return set([s for s in data if s not in existing_ids])
+        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        existing_ids = {str(x["_id"]) async for x in cursor}
+        return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        update_tasks = []
         for k, v in data.items():
-            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
             data[k]["_id"] = k
+            update_tasks.append(
+                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            )
+        await asyncio.gather(*update_tasks)

     async def drop(self) -> None:
         """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
-        result = list(self._data.aggregate(pipeline))
+        cursor = self._data.aggregate(pipeline)
+        result = await cursor.to_list()
         counts = {}
         for doc in result:
             counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        result = list(self._data.find({"status": status.value}))
+        cursor = self._data.find({"status": status.value})
+        result = await cursor.to_list()
         return {
             doc["_id"]: DocProcessingStatus(
                 content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
             global_config=global_config,
             embedding_func=embedding_func,
         )
-        self.client = AsyncIOMotorClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
-        )
-        self.db = self.client[
-            os.environ.get(
-                "MONGO_DATABASE",
-                mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        ]
-        self.collection = self.db[
-            os.environ.get(
-                "MONGO_KG_COLLECTION",
-                config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
-            )
-        ]
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
+        )
+        client = AsyncIOMotorClient(uri)
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
+
+        self._collection_name = self.namespace
+        self.collection = database.get_collection(self._collection_name)
+
+        logger.debug(f"Use MongoDB as KG {self._collection_name}")
+
+        # Ensure collection exists
+        create_collection_if_not_exists(uri, database.name, self._collection_name)

         #
         # -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
         self, source_node_id: str
     ) -> Union[List[Tuple[str, str]], None]:
         """
-        Return a list of (target_id, relation) for direct edges from source_node_id.
+        Return a list of (source_id, target_id) for direct edges from source_node_id.
         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
         """
         pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
             return None

         edges = result[0].get("edges", [])
-        return [(e["target"], e["relation"]) for e in edges]
+        return [(source_node_id, e["target"]) for e in edges]

         #
         # -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):

     async def delete_node(self, node_id: str):
         """
-        1) Remove node’s doc entirely.
+        1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
         """
         # Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
Placeholder for demonstration, raises NotImplementedError.
|
Placeholder for demonstration, raises NotImplementedError.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||||
|
|
||||||
|
#
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# QUERY
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all existing node _id in the database
|
||||||
|
Returns:
|
||||||
|
[id1, id2, ...] # Alphabetically sorted id list
|
||||||
|
"""
|
||||||
|
# Use MongoDB's distinct and aggregation to get all unique labels
|
||||||
|
pipeline = [
|
||||||
|
{"$group": {"_id": "$_id"}}, # Group by _id
|
||||||
|
{"$sort": {"_id": 1}}, # Sort alphabetically
|
||||||
|
]
|
||||||
|
|
||||||
|
cursor = self.collection.aggregate(pipeline)
|
||||||
|
labels = []
|
||||||
|
async for doc in cursor:
|
||||||
|
labels.append(doc["_id"])
|
||||||
|
return labels
|
||||||
|
|
||||||
|
async def get_knowledge_graph(
|
||||||
|
self, node_label: str, max_depth: int = 5
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
"""
|
||||||
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_label: Label of the nodes to start from
|
||||||
|
max_depth: Maximum depth of traversal (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KnowledgeGraph object containing nodes and edges of the subgraph
|
||||||
|
"""
|
||||||
|
label = node_label
|
||||||
|
result = KnowledgeGraph()
|
||||||
|
seen_nodes = set()
|
||||||
|
seen_edges = set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if label == "*":
|
||||||
|
# Get all nodes and edges
|
||||||
|
async for node_doc in self.collection.find({}):
|
||||||
|
node_id = str(node_doc["_id"])
|
||||||
|
if node_id not in seen_nodes:
|
||||||
|
result.nodes.append(
|
||||||
|
KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[node_doc.get("_id")],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in node_doc.items()
|
||||||
|
if k not in ["_id", "edges"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
|
||||||
|
# Process edges
|
||||||
|
for edge in node_doc.get("edges", []):
|
||||||
|
edge_id = f"{node_id}-{edge['target']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge.get("relation", ""),
|
||||||
|
source=node_id,
|
||||||
|
target=edge["target"],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in edge.items()
|
||||||
|
if k not in ["target", "relation"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
else:
|
||||||
|
# Verify if starting node exists
|
||||||
|
start_nodes = self.collection.find({"_id": label})
|
||||||
|
start_nodes_exist = await start_nodes.to_list(length=1)
|
||||||
|
if not start_nodes_exist:
|
||||||
|
logger.warning(f"Starting node with label {label} does not exist!")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Use $graphLookup for traversal
|
||||||
|
pipeline = [
|
||||||
|
{
|
||||||
|
"$match": {"_id": label}
|
||||||
|
}, # Start with nodes having the specified label
|
||||||
|
{
|
||||||
|
"$graphLookup": {
|
||||||
|
"from": self._collection_name,
|
||||||
|
"startWith": "$edges.target",
|
||||||
|
"connectFromField": "edges.target",
|
||||||
|
"connectToField": "_id",
|
||||||
|
"maxDepth": max_depth,
|
||||||
|
"depthField": "depth",
|
||||||
|
"as": "connected_nodes",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
async for doc in self.collection.aggregate(pipeline):
|
||||||
|
# Add the start node
|
||||||
|
node_id = str(doc["_id"])
|
||||||
|
if node_id not in seen_nodes:
|
||||||
|
result.nodes.append(
|
||||||
|
KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[
|
||||||
|
doc.get(
|
||||||
|
"_id",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in doc.items()
|
||||||
|
if k
|
||||||
|
not in [
|
||||||
|
"_id",
|
||||||
|
"edges",
|
||||||
|
"connected_nodes",
|
||||||
|
"depth",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
|
||||||
|
# Add edges from start node
|
||||||
|
for edge in doc.get("edges", []):
|
||||||
|
edge_id = f"{node_id}-{edge['target']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge.get("relation", ""),
|
||||||
|
source=node_id,
|
||||||
|
target=edge["target"],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in edge.items()
|
||||||
|
if k not in ["target", "relation"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
|
# Add connected nodes and their edges
|
||||||
|
for connected in doc.get("connected_nodes", []):
|
||||||
|
node_id = str(connected["_id"])
|
||||||
|
if node_id not in seen_nodes:
|
||||||
|
result.nodes.append(
|
||||||
|
KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[connected.get("_id")],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in connected.items()
|
||||||
|
if k not in ["_id", "edges", "depth"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
|
||||||
|
# Add edges from connected nodes
|
||||||
|
for edge in connected.get("edges", []):
|
||||||
|
edge_id = f"{node_id}-{edge['target']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge.get("relation", ""),
|
||||||
|
source=node_id,
|
||||||
|
target=edge["target"],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in edge.items()
|
||||||
|
if k not in ["target", "relation"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"MongoDB query failed: {str(e)}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|

@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]

        logger.debug(f"Use MongoDB as VDB {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)

        # Ensure vector index exists
        self.create_vector_index(uri, database.name, self._collection_name)
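Review note (not part of the diff): because `__post_init__` raises when the threshold is absent, callers must pass it through `vector_db_storage_cls_kwargs`; the URI and database fall back from the `MONGO_URI`/`MONGO_DATABASE` environment variables to config.ini. A hedged sketch of wiring this up; `LightRAG` constructor arguments other than the storage kwargs visible in this diff are assumptions:

```python
# Sketch: selecting the new Mongo vector storage and supplying the required threshold.
import os
from lightrag import LightRAG

os.environ["MONGO_URI"] = "mongodb://root:root@localhost:27017/"
os.environ["MONGO_DATABASE"] = "LightRAG"

rag = LightRAG(
    working_dir="./rag_storage",
    vector_storage="MongoVectorDBStorage",
    vector_db_storage_cls_kwargs={
        # Hits with a vectorSearchScore below this are filtered out in query()
        "cosine_better_than_threshold": 0.2,
    },
)
```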

    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
        """Creates an Atlas Vector Search index."""
        client = MongoClient(uri)
        collection = client.get_database(database_name).get_collection(
            self._collection_name
        )

        try:
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
                            "path": "vector",
                            "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
                        }
                    ]
                },
                name="vector_knn_index",
                type="vectorSearch",
            )

            collection.create_search_index(search_index_model)
            logger.info("Vector index created successfully.")

        except PyMongoError as _:
            logger.debug("Vector index already exists.")
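Review note (not part of the diff): index creation is made idempotent by swallowing `PyMongoError`; an explicit existence check is also possible on recent drivers. A sketch, assuming `list_search_indexes` is available in the deployed PyMongo version (4.6+ against an Atlas-compatible server):

```python
# Sketch: explicit check instead of relying on the duplicate-index error path.
from pymongo import MongoClient

def vector_index_exists(uri: str, db: str, coll: str, name: str = "vector_knn_index") -> bool:
    collection = MongoClient(uri)[db][coll]
    # Each returned document describes one search index and carries a "name" key
    return any(ix["name"] == name for ix in collection.list_search_indexes())
```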

    async def upsert(self, data: dict[str, dict]):
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()

        update_tasks = []
        for doc in list_data:
            update_tasks.append(
                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            )
        await asyncio.gather(*update_tasks)

        return list_data
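Review note (not part of the diff): the write phase issues one `update_one` per document and gathers the futures; for large batches a single `bulk_write` round-trip is the usual alternative. A sketch under that assumption, not a change this commit makes:

```python
# Sketch: same upsert semantics in one bulk round-trip (Motor/PyMongo).
from pymongo import UpdateOne

async def bulk_upsert(collection, list_data: list[dict]) -> None:
    ops = [
        UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
        for doc in list_data
    ]
    if ops:
        # ordered=False lets the server apply independent upserts in parallel
        await collection.bulk_write(ops, ordered=False)
```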

    async def query(self, query, top_k=5):
        """Queries the vector database using Atlas Vector Search."""
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        # Define the aggregation pipeline with the converted query vector
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Ensure this matches the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    "numCandidates": 100,  # Adjust for performance
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline
        cursor = self._data.aggregate(pipeline)
        results = await cursor.to_list()

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]
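Review note (not part of the diff): callers receive plain dicts with `id` and `distance` keys layered over the stored meta fields. A quick usage sketch; storage construction is elided (see `__post_init__` above) and the query text is arbitrary:

```python
# Sketch: querying the Mongo-backed vector storage.
async def demo(storage) -> None:
    # storage: an initialized MongoVectorDBStorage instance
    results = await storage.query("Who is Scrooge?", top_k=3)
    for hit in results:
        # "id" and "distance" are added in query(); other keys are meta fields
        print(hit["id"], hit["distance"])
```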

def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists; if not, create it."""
    client = MongoClient(uri)
    database = client.get_database(database_name)

    collection_names = database.list_collection_names()

    if collection_name not in collection_names:
        database.create_collection(collection_name)
        logger.info(f"Created collection: {collection_name}")
    else:
        logger.debug(f"Collection '{collection_name}' already exists.")

@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")

-    async def delete_entity_relation(self, entity_name: str):
+    async def delete_entity_relation(self, entity_name: str) -> None:
         try:
             relations = [
                 dp
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
     async def index_done_callback(self):
         print("KG successfully indexed.")

-    async def has_node(self, node_id: str) -> bool:
-        entity_name_label = node_id.strip('"')
+    async def _label_exists(self, label: str) -> bool:
+        """Check if a label exists in the Neo4j database."""
+        query = "CALL db.labels() YIELD label RETURN label"
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                result = await session.run(query)
+                labels = [record["label"] for record in await result.data()]
+                return label in labels
+        except Exception as e:
+            logger.error(f"Error checking label existence: {e}")
+            return False
+
+    async def _ensure_label(self, label: str) -> str:
+        """Ensure a label exists by validating it."""
+        clean_label = label.strip('"')
+        if not await self._label_exists(clean_label):
+            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
+        return clean_label
+
+    async def has_node(self, node_id: str) -> bool:
+        entity_name_label = await self._ensure_label(node_id)
         async with self._driver.session(database=self._DATABASE) as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
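Review note (not part of the diff): `_ensure_label` runs `CALL db.labels()` on every node and edge access, which is a full label listing per call. A hedged sketch of memoizing the label set on the storage instance; `self._label_cache` is a hypothetical attribute, not something this commit adds:

```python
# Sketch: cache the label set to avoid re-running CALL db.labels() per lookup.
async def _label_exists(self, label: str) -> bool:
    if getattr(self, "_label_cache", None) is None:
        async with self._driver.session(database=self._DATABASE) as session:
            result = await session.run("CALL db.labels() YIELD label RETURN label")
            self._label_cache = {record["label"] for record in await result.data()}
    return label in self._label_cache
```

The trade-off is staleness: labels created after the first call would need an explicit cache invalidation.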
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
             return single_result["edgeExists"]

     async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Get node by its label identifier.
+
+        Args:
+            node_id: The node label to look up
+
+        Returns:
+            dict: Node properties if found
+            None: If node not found
+        """
         async with self._driver.session(database=self._DATABASE) as session:
-            entity_name_label = node_id.strip('"')
+            entity_name_label = await self._ensure_label(node_id)
             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
             result = await session.run(query)
             record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
-        """
-        Find all edges between nodes of two given labels
+        """Find edge between two nodes identified by their labels.

         Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
+            source_node_id (str): Label of the source node
+            target_node_id (str): Label of the target node

         Returns:
-            list: List of all relationships/edges found
+            dict: Edge properties if found, with at least {"weight": 0.0}
+            None: If error occurs
         """
-        async with self._driver.session(database=self._DATABASE) as session:
-            query = f"""
-            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
-            RETURN properties(r) as edge_properties
-            LIMIT 1
-            """.format(
-                entity_name_label_source=entity_name_label_source,
-                entity_name_label_target=entity_name_label_target,
-            )
-
-            result = await session.run(query)
-            record = await result.single()
-            if record:
-                result = dict(record["edge_properties"])
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                )
-                return result
-            else:
-                return None
+        try:
+            entity_name_label_source = source_node_id.strip('"')
+            entity_name_label_target = target_node_id.strip('"')
+
+            async with self._driver.session(database=self._DATABASE) as session:
+                query = f"""
+                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+                RETURN properties(r) as edge_properties
+                LIMIT 1
+                """.format(
+                    entity_name_label_source=entity_name_label_source,
+                    entity_name_label_target=entity_name_label_target,
+                )
+
+                result = await session.run(query)
+                record = await result.single()
+
+                if record and "edge_properties" in record:
+                    try:
+                        result = dict(record["edge_properties"])
+                        # Ensure required keys exist with defaults
+                        required_keys = {
+                            "weight": 0.0,
+                            "source_id": None,
+                            "target_id": None,
+                        }
+                        for key, default_value in required_keys.items():
+                            if key not in result:
+                                result[key] = default_value
+                                logger.warning(
+                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                    f"missing {key}, using default: {default_value}"
+                                )
+
+                        logger.debug(
+                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+                        )
+                        return result
+                    except (KeyError, TypeError, ValueError) as e:
+                        logger.error(
+                            f"Error processing edge properties between {entity_name_label_source} "
+                            f"and {entity_name_label_target}: {str(e)}"
+                        )
+                        # Return default edge properties on error
+                        return {"weight": 0.0, "source_id": None, "target_id": None}
+
+                logger.debug(
+                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                )
+                # Return default edge properties when no edge found
+                return {"weight": 0.0, "source_id": None, "target_id": None}
+
+        except Exception as e:
+            logger.error(
+                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
+            )
+            # Return default edge properties on error
+            return {"weight": 0.0, "source_id": None, "target_id": None}

     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
         node_label = source_node_id.strip('"')
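Review note (not part of the diff): after this change `get_edge` no longer returns `None` for a missing edge; callers always receive at least `weight`, `source_id`, and `target_id` keys. A sketch of the simpler caller-side handling this enables:

```python
# Sketch: downstream code can drop its None/KeyError guards, because get_edge
# now guarantees these keys (values may still be the 0.0 / None defaults).
async def accumulate_weight(graph, pairs) -> float:
    total = 0.0
    for src, tgt in pairs:
        edge = await graph.get_edge(src, tgt)
        total += edge["weight"]  # safe: defaults to 0.0 when the edge is absent
    return total
```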
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
-        label = node_id.strip('"')
+        label = await self._ensure_label(node_id)
         properties = node_data

         async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
                     neo4jExceptions.ServiceUnavailable,
                     neo4jExceptions.TransientError,
                     neo4jExceptions.WriteServiceUnavailable,
+                    neo4jExceptions.ClientError,
                 )
             ),
         )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
         """
-        source_node_label = source_node_id.strip('"')
-        target_node_label = target_node_id.strip('"')
+        source_label = await self._ensure_label(source_node_id)
+        target_label = await self._ensure_label(target_node_id)
         edge_properties = edge_data

         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
-            MATCH (source:`{source_node_label}`)
+            MATCH (source:`{source_label}`)
             WITH source
-            MATCH (target:`{target_node_label}`)
+            MATCH (target:`{target_label}`)
             MERGE (source)-[r:DIRECTED]->(target)
             SET r += $properties
             RETURN r
             """
-            await tx.run(query, properties=edge_properties)
+            result = await tx.run(query, properties=edge_properties)
+            record = await result.single()
             logger.debug(
-                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
+                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
             )

         try:
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
         "FaissVectorDBStorage",
         "QdrantVectorDBStorage",
         "OracleVectorDBStorage",
+        "MongoVectorDBStorage",
     ],
     "required_methods": ["query", "upsert"],
 },
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
 }

 # Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
         "ORACLE_PASSWORD",
         "ORACLE_CONFIG_DIR",
     ],
+    "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
     "MongoKVStorage": ".kg.mongo_impl",
     "MongoDocStatusStorage": ".kg.mongo_impl",
     "MongoGraphStorage": ".kg.mongo_impl",
+    "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     "TiDBKVStorage": ".kg.tidb_impl",
@@ -176,7 +181,7 @@ STORAGES = {
 }


-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
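Review note (not part of the diff): registering `MongoVectorDBStorage` in `STORAGES` above is all `_get_storage_class` needs to resolve it, because `lazy_external_import` defers the actual import to first construction. An illustrative sketch of the resolution path:

```python
# Sketch: how the registry entry added above is resolved at runtime.
# lazy_external_import returns a factory; ".kg.mongo_impl" is only imported
# when the factory is first called.
storage_cls = lazy_external_import(".kg.mongo_impl", "MongoVectorDBStorage")
# LightRAG later wraps this with partial(..., global_config=global_config) and
# instantiates it per namespace; calling it here would need a full config dict.
```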
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None

-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib

         module = importlib.import_module(module_name, package=package)
@@ -302,7 +307,7 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """

-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None
     """Function for computing text embeddings. Must be set before use."""

     embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:

     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""

-    # extension
-    addon_params: dict[str, Any] = field(default_factory=dict)
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
     )
@@ -354,7 +357,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
@@ -443,77 +446,74 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

         # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
             self._get_storage_class(self.kv_storage)
-        )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
             self.vector_storage
-        )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
             self.graph_storage
-        )
-        self.key_string_value_json_storage_cls = partial(
+        )  # type: ignore
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )

-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )

-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
@@ -527,13 +527,12 @@ class LightRAG:
             embedding_func=None,
         )

-        # What's for, Is this nessisary ?
         if self.llm_response_cache and hasattr(
             self.llm_response_cache, "global_config"
         ):
             hashing_kv = self.llm_response_cache
         else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
@@ -542,7 +541,7 @@ class LightRAG:

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -559,68 +558,45 @@ class LightRAG:
             node_label=nodel_label, max_depth=max_depth
         )

-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class

-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
-
     def insert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Sync Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
-            chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
                 split_by_character is None, this parameter is ignored.
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
         )

     async def ainsert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Async Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
-            chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
                 split_by_character is None, this parameter is ignored.
         """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
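Review note (not part of the diff): the public argument was renamed from `string_or_strings` to `input`, which is a breaking change for keyword callers. A sketch of the updated call sites; the `rag` instance construction is elided:

```python
# Sketch: positional callers are unaffected; keyword callers must rename.
def demo(rag) -> None:
    rag.insert("Single document")                       # unchanged
    rag.insert(["doc one", "doc two"])                  # unchanged
    rag.insert(input=["doc one", "doc two"])            # was string_or_strings=[...]
    rag.insert("long text", split_by_character="\n\n")  # split on blank lines first
```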
@@ -677,7 +653,7 @@ class LightRAG:
         if update_storage:
             await self._insert_done()

-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
         """
         Pipeline for Processing Documents
@@ -686,11 +662,11 @@ class LightRAG:
         3. Filter out already processed documents
         4. Enqueue document in status
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]

         # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))

         # 2. Generate document IDs and initial status
         new_docs: dict[str, Any] = {
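Review note (not part of the diff): steps 1 and 2 of the enqueue pipeline dedupe by stripped content and derive stable, content-addressed document IDs, so re-inserting the same text is a no-op at the status layer. A sketch of the ID derivation; the helper is assumed to be the prefix-plus-MD5 `compute_mdhash_id` used elsewhere in this diff:

```python
# Sketch: content-addressed doc IDs make enqueueing idempotent.
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Assumed to mirror the helper used in the diff: prefix + md5 hex digest
    return prefix + md5(content.encode()).hexdigest()

docs = ["same text", "same text ", "other text"]
unique_contents = list({doc.strip() for doc in docs})
new_docs = {compute_mdhash_id(c, prefix="doc-"): {"content": c} for c in unique_contents}
print(len(new_docs))  # 2: duplicate stripped content collapses to one ID
```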
@@ -857,32 +833,32 @@ class LightRAG:
                 raise e

     async def _insert_done(self):
-        tasks = []
-        for storage_inst in [
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunks_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+        tasks = [
+            cast(StorageNameSpace, storage_inst).index_done_callback()
+            for storage_inst in [  # type: ignore
+                self.full_docs,
+                self.text_chunks,
+                self.llm_response_cache,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+            ]
+            if storage_inst is not None
+        ]
         await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, Any]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
         update_storage = False
         try:
             # Insert chunks into vector storage
-            all_chunks_data = {}
-            chunk_to_source_map = {}
-            for chunk_data in custom_kg.get("chunks", []):
+            all_chunks_data: dict[str, dict[str, str]] = {}
+            chunk_to_source_map: dict[str, str] = {}
+            for chunk_data in custom_kg.get("chunks", {}):
                 chunk_content = chunk_data["content"]
                 source_id = chunk_data["source_id"]
                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
                 chunk_to_source_map[source_id] = chunk_id
                 update_storage = True

-            if self.chunks_vdb is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.chunks_vdb.upsert(all_chunks_data)
-            if self.text_chunks is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.text_chunks.upsert(all_chunks_data)

             # Insert entities into knowledge graph
-            all_entities_data = []
+            all_entities_data: list[dict[str, str]] = []
             for entity_data in custom_kg.get("entities", []):
                 entity_name = f'"{entity_data["entity_name"].upper()}"'
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
             )

             # Prepare node data
-            node_data = {
+            node_data: dict[str, str] = {
                 "entity_type": entity_type,
                 "description": description,
                 "source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
                 update_storage = True

             # Insert relationships into knowledge graph
-            all_relationships_data = []
+            all_relationships_data: list[dict[str, str]] = []
             for relationship_data in custom_kg.get("relationships", []):
                 src_id = f'"{relationship_data["src_id"].upper()}"'
                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
                     "source_id": source_id,
                 },
             )
-            edge_data = {
+            edge_data: dict[str, str] = {
                 "src_id": src_id,
                 "tgt_id": tgt_id,
                 "description": description,
@@ -980,41 +956,68 @@ class LightRAG:
                 update_storage = True

             # Insert entities into vector storage if needed
-            if self.entities_vdb is not None:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                        "content": dp["entity_name"] + dp["description"],
-                        "entity_name": dp["entity_name"],
-                    }
-                    for dp in all_entities_data
+            data_for_vdb = {
+                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                    "content": dp["entity_name"] + dp["description"],
+                    "entity_name": dp["entity_name"],
                 }
-                await self.entities_vdb.upsert(data_for_vdb)
+                for dp in all_entities_data
+            }
+            await self.entities_vdb.upsert(data_for_vdb)

             # Insert relationships into vector storage if needed
-            if self.relationships_vdb is not None:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                        "src_id": dp["src_id"],
-                        "tgt_id": dp["tgt_id"],
-                        "content": dp["keywords"]
-                        + dp["src_id"]
-                        + dp["tgt_id"]
-                        + dp["description"],
-                    }
-                    for dp in all_relationships_data
+            data_for_vdb = {
+                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                    "src_id": dp["src_id"],
+                    "tgt_id": dp["tgt_id"],
+                    "content": dp["keywords"]
+                    + dp["src_id"]
+                    + dp["tgt_id"]
+                    + dp["description"],
                 }
-                await self.relationships_vdb.upsert(data_for_vdb)
+                for dp in all_relationships_data
+            }
+            await self.relationships_vdb.upsert(data_for_vdb)

         finally:
             if update_storage:
                 await self._insert_done()

-    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+    def query(
+        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
+    ) -> str | Iterator[str]:
+        """
+        Perform a sync query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.aquery(query, prompt, param))
+        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore

     async def aquery(
-        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
-    ):
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str | AsyncIterator[str]:
+        """
+        Perform an async query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query,
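Review note (not part of the diff): `query`/`aquery` now take `param` before the optional `prompt` and may return an iterator when streaming is enabled, so positional callers that passed `prompt` second must be updated. A sketch of the adjusted call; `QueryParam` fields other than `mode` are not shown in this diff and are assumed:

```python
# Sketch: new argument order, query(query, param, prompt).
from lightrag import QueryParam

def demo(rag) -> None:
    answer = rag.query(
        "What does Scrooge learn?",
        param=QueryParam(mode="hybrid"),
        prompt=None,  # None falls back to PROMPTS["rag_response"]
    )
    print(answer)
```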
@@ -1094,7 +1097,7 @@ class LightRAG:

     async def aquery_with_separate_keyword_extraction(
         self, query: str, prompt: str, param: QueryParam = QueryParam()
-    ):
+    ) -> str | AsyncIterator[str]:
         """
         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
         2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
             ),
         )

-        param.hl_keywords = (hl_keywords,)
-        param.ll_keywords = (ll_keywords,)
+        param.hl_keywords = hl_keywords
+        param.ll_keywords = ll_keywords

         # ---------------------
         # STEP 2: Final Query Logic
@@ -1146,7 +1149,7 @@ class LightRAG:
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
                 global_config=asdict(self),
-                embedding_func=self.embedding_funcne,
+                embedding_func=self.embedding_func,
             ),
         )
         elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
         return response

     async def _query_done(self):
-        tasks = []
-        for storage_inst in [self.llm_response_cache]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await self.llm_response_cache.index_done_callback()

     def delete_by_entity(self, entity_name: str):
         loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
             logger.error(f"Error while deleting entity '{entity_name}': {e}")

     async def _delete_by_entity_done(self):
-        tasks = []
-        for storage_inst in [
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in [  # type: ignore
+                    self.entities_vdb,
+                    self.relationships_vdb,
+                    self.chunk_entity_relation_graph,
+                ]
+            ]
+        )

     def _get_content_summary(self, content: str, max_length: int = 100) -> str:
         """Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
         """
         return await self.doc_status.get_status_counts()

-    async def adelete_by_doc_id(self, doc_id: str):
+    async def adelete_by_doc_id(self, doc_id: str) -> None:
         """Delete a document and all its related data

         Args:
@@ -1273,6 +1271,9 @@ class LightRAG:

         # 2. Get all related chunks
         chunks = await self.text_chunks.get_by_id(doc_id)
+        if not chunks:
+            return
+
         chunk_ids = list(chunks.keys())
         logger.debug(f"Found {len(chunk_ids)} chunks to delete")
@@ -1443,13 +1444,9 @@ class LightRAG:
         except Exception as e:
             logger.error(f"Error while deleting document {doc_id}: {e}")

-    def delete_by_doc_id(self, doc_id: str):
-        """Synchronous version of adelete"""
-        return asyncio.run(self.adelete_by_doc_id(doc_id))
-
     async def get_entity_info(
         self, entity_name: str, include_vector_data: bool = False
-    ):
+    ) -> dict[str, str | None | dict[str, str]]:
         """Get detailed information of an entity

         Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
         node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
         source_id = node_data.get("source_id") if node_data else None

-        result = {
+        result: dict[str, str | None | dict[str, str]] = {
             "entity_name": entity_name,
             "source_id": source_id,
             "graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:

         return result

-    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
-        """Synchronous version of getting entity information
-
-        Args:
-            entity_name: Entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
-        finally:
-            tracemalloc.stop()
-
     async def get_relation_info(
         self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
@@ -1525,7 +1507,7 @@ class LightRAG:
         )
         source_id = edge_data.get("source_id") if edge_data else None

-        result = {
+        result: dict[str, str | None | dict[str, str]] = {
             "src_entity": src_entity,
             "tgt_entity": tgt_entity,
             "source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
             result["vector_data"] = vector_data[0] if vector_data else None

         return result
-
-    def get_relation_info_sync(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
-    ):
-        """Synchronous version of getting relationship information
-
-        Args:
-            src_entity: Source entity name (no need for quotes)
-            tgt_entity: Target entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(
-                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
-            )
-        finally:
-            tracemalloc.stop()
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """

-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
@@ -66,7 +68,11 @@ class MultiModel:
         return self._models[self._current_model]

     async def llm_model_func(
-        self, prompt, system_prompt=None, history_messages=[], **kwargs
+        self,
+        prompt: str,
+        system_prompt: str | None = None,
+        history_messages: list[dict[str, Any]] = [],
+        **kwargs: Any,
     ) -> str:
         kwargs.pop("model", None)  # stop from overwriting the custom model name
         kwargs.pop("keyword_extraction", None)
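Review note (not part of the diff): only annotations changed here; MultiModel's round-robin behavior is untouched. A sketch of rotating across providers; the model function is a placeholder, and the generator field is assumed to be named `gen_func` as in upstream LightRAG (only the `kwargs` field name is visible in this hunk):

```python
# Sketch: MultiModel cycles through Model entries on successive calls.
async def fake_llm(prompt: str, **kwargs) -> str:  # stand-in generator function
    return f"echo[{kwargs.get('tag')}]: {prompt[:20]}"

models = [
    Model(gen_func=fake_llm, kwargs={"tag": "provider-a"}),
    Model(gen_func=fake_llm, kwargs={"tag": "provider-b"}),
]
multi = MultiModel(models)
# LightRAG would then be configured with llm_model_func=multi.llm_model_func
```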
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
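Review note (not part of the diff): with the modernized annotation, `split_by_character` is an ordinary optional string. A sketch of the two chunking modes this signature supports; the import path and any tokenizer-related parameters not visible in this hunk are assumptions and left at their defaults:

```python
# Sketch: token-window chunking vs. character-first chunking.
from lightrag.operate import chunking_by_token_size

text = "para one\n\npara two\n\npara three"

# Pure token-size sliding window (default behavior):
chunks = chunking_by_token_size(text, max_token_size=1024, overlap_token_size=128)

# Split on a delimiter first; oversized pieces are re-split by token size:
chunks = chunking_by_token_size(text, split_by_character="\n\n")
```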
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(

     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
-        already_weights.append(already_edge["weight"])
-        already_source_ids.extend(
-            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
-        )
-        already_description.append(already_edge["description"])
-        already_keywords.extend(
-            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
-        )
+        # Handle the case where get_edge returns None or missing fields
+        if already_edge:
+            # Get weight with default 0.0 if missing
+            if "weight" in already_edge:
+                already_weights.append(already_edge["weight"])
+            else:
+                logger.warning(
+                    f"Edge between {src_id} and {tgt_id} missing weight field"
+                )
+                already_weights.append(0.0)
+
+            # Get source_id with empty string default if missing or None
+            if "source_id" in already_edge and already_edge["source_id"] is not None:
+                already_source_ids.extend(
+                    split_string_by_multi_markers(
+                        already_edge["source_id"], [GRAPH_FIELD_SEP]
+                    )
+                )
+
+            # Get description with empty string default if missing or None
+            if (
+                "description" in already_edge
+                and already_edge["description"] is not None
+            ):
+                already_description.append(already_edge["description"])
+
+            # Get keywords with empty string default if missing or None
+            if "keywords" in already_edge and already_edge["keywords"] is not None:
+                already_keywords.extend(
+                    split_string_by_multi_markers(
+                        already_edge["keywords"], [GRAPH_FIELD_SEP]
+                    )
+                )
+
+    # Process edges_data with None checks
     weight = sum([dp["weight"] for dp in edges_data] + already_weights)
     description = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["description"] for dp in edges_data] + already_description))
+        sorted(
+            set(
+                [dp["description"] for dp in edges_data if dp.get("description")]
+                + already_description
+            )
+        )
     )
     keywords = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
+        sorted(
+            set(
+                [dp["keywords"] for dp in edges_data if dp.get("keywords")]
+                + already_keywords
+            )
+        )
     )
     source_id = GRAPH_FIELD_SEP.join(
-        set([dp["source_id"] for dp in edges_data] + already_source_ids)
+        set(
+            [dp["source_id"] for dp in edges_data if dp.get("source_id")]
+            + already_source_ids
+        )
     )

     for need_insert_id in [src_id, tgt_id]:
         if not (await knowledge_graph_inst.has_node(need_insert_id)):
             await knowledge_graph_inst.upsert_node(
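The rewritten block stops trusting `get_edge` to return a complete record: every field is checked before use, a missing `weight` is logged and defaulted to `0.0`, and the `edges_data` aggregations now skip falsy fields via `dp.get(...)`. A toy sketch of the same guard pattern (hypothetical `stored` dict, not the repo's storage API):

```python
from __future__ import annotations

import logging

logging.basicConfig()
logger = logging.getLogger("merge")

def merge_weights(stored: dict | None, new_weights: list[float]) -> float:
    """Sum incoming edge weights with a previously stored weight,
    tolerating a missing edge or a missing 'weight' field."""
    already: list[float] = []
    if stored:  # the store may return None for a half-written edge
        if "weight" in stored:
            already.append(stored["weight"])
        else:
            logger.warning("stored edge is missing its weight field")
            already.append(0.0)
    return sum(new_weights + already)

print(merge_weights({"weight": 2.0}, [1.0]))  # 3.0
print(merge_weights({}, [1.0]))               # 1.0 (plus a warning)
print(merge_weights(None, [1.0]))             # 1.0
```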
@@ -295,9 +337,9 @@ async def extract_entities(
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
-) -> Union[BaseGraphStorage, None]:
+    global_config: dict[str, str],
+    llm_response_cache: BaseKVStorage | None = None,
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -563,15 +605,15 @@ async def extract_entities(


 async def kg_query(
-    query,
+    query: str,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-    prompt: str = "",
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -684,8 +726,8 @@ async def kg_query(
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> tuple[list[str], list[str]]:
     """
     Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.

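The widened return type `str | AsyncIterator[str]` here (and on `naive_query` and `kg_query_with_keywords` below) records that a query may come back either as a finished string or as a stream of chunks. A caller sketch, using a stand-in `consume` helper rather than anything from the repo:

```python
from __future__ import annotations

import asyncio
from typing import AsyncIterator

async def consume(result: str | AsyncIterator[str]) -> str:
    """Normalize a maybe-streaming result into a single string."""
    if isinstance(result, str):
        return result
    # A streaming result is an async iterator of chunks; join them.
    return "".join([chunk async for chunk in result])

async def fake_stream() -> AsyncIterator[str]:
    for part in ("Hel", "lo"):
        yield part

print(asyncio.run(consume("Hello")))        # Hello
print(asyncio.run(consume(fake_stream())))  # Hello
```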
@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):


 async def naive_query(
-    query,
+    query: str,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-):
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Refactored kg_query that does NOT extract keywords by itself.
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS = {}
@@ -1,26 +1,28 @@
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any, Optional


 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]


 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraphEdge(BaseModel):
     id: str
-    type: str
+    type: Optional[str]
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
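Making `KnowledgeGraphEdge.type` an `Optional[str]` means an edge without a relationship label still validates; the frontend changes further down mirror this with `type?: string`. A quick usage sketch that restates the two models so it runs standalone:

```python
from typing import Any, Optional

from pydantic import BaseModel

class KnowledgeGraphNode(BaseModel):
    id: str
    labels: list[str]
    properties: dict[str, Any]

class KnowledgeGraphEdge(BaseModel):
    id: str
    type: Optional[str]  # still required, but None is now a valid value
    source: str
    target: str
    properties: dict[str, Any]

a = KnowledgeGraphNode(id="a", labels=["Person"], properties={})
b = KnowledgeGraphNode(id="b", labels=["City"], properties={})
edge = KnowledgeGraphEdge(id="e1", type=None, source="a", target="b", properties={})
print(edge.type)  # None -- an unlabeled edge no longer fails validation
```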
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4

@@ -67,12 +69,12 @@ class EmbeddingFunc:

 @dataclass
 class ReasoningResponse:
-    reasoning_content: str
+    reasoning_content: str | None
     response_content: str
     tag: str


-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
     raise e from None


-def compute_args_hash(*args, cache_type: str = None) -> str:
+def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     return hashlib.md5(args_str.encode()).hexdigest()


-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+    """
+    Compute a unique ID for a given content string.
+
+    The ID is a combination of the given prefix and the MD5 hash of the content string.
+    """
     return prefix + md5(content.encode()).hexdigest()


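The new docstring pins down the contract of `compute_mdhash_id`: the ID is simply the prefix plus the MD5 hex digest of the content, so identical content always yields the same ID. For example:

```python
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Deterministic: the same content always hashes to the same ID.
    return prefix + md5(content.encode()).hexdigest()

a = compute_mdhash_id("some chunk of text", prefix="chunk-")
b = compute_mdhash_id("some chunk of text", prefix="chunk-")
print(a)       # chunk-<32 hex chars>
print(a == b)  # True -- same content, same ID, so re-inserting is idempotent
```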
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
     return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


-def is_float_regex(value):
+def is_float_regex(value: str) -> bool:
     return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


-def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
+def truncate_list_by_token_size(
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
         return []
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
     return list_data


-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
     return output.getvalue()


-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")

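`list_of_list_to_csv` and `csv_string_to_list` are intended as inverses over the standard `csv` module (the repo's writer also passes quoting options not visible in these hunks). A minimal round-trip sketch under that assumption:

```python
import csv
import io

def list_of_list_to_csv(data: list[list[str]]) -> str:
    output = io.StringIO()
    csv.writer(output).writerows(data)
    return output.getvalue()

def csv_string_to_list(csv_string: str) -> list[list[str]]:
    cleaned = csv_string.replace("\0", "")  # drop NUL chars, as the diff does
    return list(csv.reader(io.StringIO(cleaned)))

table = [["id", "entity"], ["1", "Alice"]]
assert csv_string_to_list(list_of_list_to_csv(table)) == table
```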
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
     return None


-def process_combine_contexts(hl, ll):
+def process_combine_contexts(hl: str, ll: str):
     header = None
     list_hl = csv_string_to_list(hl.strip())
     list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)


-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
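Only the signature of `quantize_embedding` is visible in this hunk; a plausible min-max reading of it, returning the quantized codes plus the `min_val`/`max_val` that `CacheData` below stores for later dequantization, might look like this (an assumption, not the repo's exact scheme):

```python
from __future__ import annotations

import numpy as np

def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    # Convert list to numpy array if needed
    if isinstance(embedding, list):
        embedding = np.array(embedding)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)  # avoid divide-by-zero
    codes = np.round((embedding - min_val) * scale).astype(np.uint8)  # bits <= 8
    return codes, min_val, max_val

codes, lo, hi = quantize_embedding([0.0, 0.5, 1.0])
print(codes)   # [  0 128 255]
print(lo, hi)  # 0.0 1.0 -- kept so a caller can dequantize later
```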
@@ -570,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"

@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
     return False


-def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
+def get_conversation_turns(
+    conversation_history: list[dict[str, Any]], num_turns: int
+) -> str:
     """
     Process conversation history to get the specified number of complete turns.

@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
     Formatted string of the conversation history
     """
     # Group messages into turns
-    turns = []
-    messages = []
+    turns: list[list[dict[str, Any]]] = []
+    messages: list[dict[str, Any]] = []

     # First, filter out keyword extraction messages
     for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
     turns = turns[-num_turns:]

     # Format the turns into a string
-    formatted_turns = []
+    formatted_turns: list[str] = []
     for turn in turns:
         formatted_turns.extend(
             [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
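Per its docstring, `get_conversation_turns` groups a flat message list into complete user/assistant turns, keeps the last `num_turns`, and formats them, now with explicit annotations on the accumulators. A simplified sketch of that grouping (omitting the repo's keyword-extraction filtering):

```python
from typing import Any

def last_turns(history: list[dict[str, Any]], num_turns: int) -> str:
    turns: list[list[dict[str, Any]]] = []
    # Pair each user message with the assistant reply that follows it.
    for i in range(len(history) - 1):
        if history[i]["role"] == "user" and history[i + 1]["role"] == "assistant":
            turns.append([history[i], history[i + 1]])
    formatted_turns: list[str] = []
    for turn in turns[-num_turns:]:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
        )
    return "\n".join(formatted_turns)

history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "bye"},
    {"role": "assistant", "content": "goodbye"},
]
print(last_turns(history, num_turns=1))  # only the bye/goodbye turn
```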
@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
       <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
       <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
         <PropertyRow name={'Id'} value={edge.id} />
-        <PropertyRow name={'Type'} value={edge.type} />
+        {edge.type && <PropertyRow name={'Type'} value={edge.type} />}
         <PropertyRow
           name={'Source'}
           value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}
@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
   }

   for (const edge of graph.edges) {
-    if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
+    if (!edge.id || !edge.source || !edge.target) {
       return false
     }
   }
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
     if (source !== undefined && source !== undefined) {
       const sourceNode = rawData.nodes[source]
       const targetNode = rawData.nodes[target]
+      if (!sourceNode) {
+        console.error(`Source node ${edge.source} is undefined`)
+        continue
+      }
+      if (!targetNode) {
+        console.error(`Target node ${edge.target} is undefined`)
+        continue
+      }
       sourceNode.degree += 1
       targetNode.degree += 1
     }
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {

   for (const rawEdge of rawGraph?.edges ?? []) {
     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
-      label: rawEdge.type
+      label: rawEdge.type || undefined
     })
   }

@@ -19,7 +19,7 @@ export type RawEdgeType = {
   id: string
   source: string
   target: string
-  type: string
+  type?: string
   properties: Record<string, any>

   dynamicId: string