Merge remote-tracking branch 'origin/main' into refactor-api-server

yangdx
2025-02-21 11:24:16 +08:00
40 changed files with 1393 additions and 592 deletions

61
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: Bug Report
description: File a bug report
title: "[Bug]: <title>"
labels: ["bug", "triage"]
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to file an issue?
description: Please help us manage our time by avoiding duplicates and common bugs with the steps below.
options:
- label: I have searched the existing issues and this bug is not already filed.
- label: I believe this is a legitimate bug, not just a question or feature request.
- type: textarea
id: description
attributes:
label: Describe the bug
description: A clear and concise description of what the bug is.
placeholder: What went wrong?
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: Steps to reproduce the behavior.
placeholder: How can we replicate the issue?
- type: textarea
id: expected_behavior
attributes:
label: Expected Behavior
description: A clear and concise description of what you expected to happen.
placeholder: What should have happened?
- type: textarea
id: configused
attributes:
label: LightRAG Config Used
description: The LightRAG configuration used for the run.
placeholder: The settings content or LightRAG configuration
value: |
# Paste your config here
- type: textarea
id: screenshotslogs
attributes:
label: Logs and screenshots
description: If applicable, add screenshots and logs to help explain your problem.
placeholder: Add logs and screenshots here
- type: textarea
id: additional_information
attributes:
label: Additional Information
description: |
- LightRAG Version: e.g., v0.1.1
- Operating System: e.g., Windows 10, Ubuntu 20.04
- Python Version: e.g., 3.8
- Related Issues: e.g., #1
- Any other relevant information.
value: |
- LightRAG Version:
- Operating System:
- Python Version:
- Related Issues:

1
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1 @@
blank_issues_enabled: false

26
.github/ISSUE_TEMPLATE/feature_request.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: Feature Request
description: File a feature request
labels: ["enhancement"]
title: "[Feature Request]: <title>"
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to file a feature request?
description: Please help us manage our time by avoiding duplicates and common feature requests with the steps below.
options:
- label: I have searched the existing feature requests and this feature request is not already filed.
- label: I believe this is a legitimate feature request, not just a question or bug.
- type: textarea
id: feature_request_description
attributes:
label: Feature Request Description
description: A clear and concise description of the feature request you would like.
placeholder: What would this feature request add or improve?
- type: textarea
id: additional_context
attributes:
label: Additional Context
description: Add any other context or screenshots about the feature request here.
placeholder: Any additional information

26
.github/ISSUE_TEMPLATE/question.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: Question
description: Ask a general question
labels: ["question"]
title: "[Question]: <title>"
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to ask a question?
description: Please help us manage our time by avoiding duplicates and common questions with the steps below.
options:
- label: I have searched the existing questions and discussions and this question is not already answered.
- label: I believe this is a legitimate question, not just a bug or feature request.
- type: textarea
id: question
attributes:
label: Your Question
description: A clear and concise description of your question.
placeholder: What is your question?
- type: textarea
id: context
attributes:
label: Additional Context
description: Provide any additional context or details that might help us understand your question better.
placeholder: Add any relevant information here

32
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,32 @@
<!--
Thanks for contributing to LightRAG!
Please ensure your pull request is ready for review before submitting.
About this template
This template helps contributors provide a clear and concise description of their changes. Feel free to adjust it as needed.
-->
## Description
[Briefly describe the changes made in this pull request.]
## Related Issues
[Reference any related issues or tasks addressed by this pull request.]
## Changes Made
[List the specific changes made in this pull request.]
## Checklist
- [ ] Changes tested locally
- [ ] Code reviewed
- [ ] Documentation updated (if necessary)
- [ ] Unit tests added (if applicable)
## Additional Notes
[Add any additional notes or context for the reviewer(s).]

View File

@@ -312,7 +312,41 @@ rag = LightRAG(
In order to run this experiment on a low-RAM GPU you should select a small model and tune the context window (increasing the context increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
</details>
<details>
<summary> <b>LlamaIndex</b> </summary>
LightRAG supports integration with LlamaIndex.
1. **LlamaIndex** (`llm/llama_index_impl.py`):
- Integrates with OpenAI and other providers through LlamaIndex
- See [LlamaIndex Documentation](lightrag/llm/Readme.md) for detailed setup and examples
### Example Usage
```python
# Using LlamaIndex with direct OpenAI access
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
# Define the embedding model referenced by the embedding function below
embed_model = OpenAIEmbedding(model="text-embedding-3-large")

rag = LightRAG(
working_dir="your/path",
llm_model_func=llama_index_complete_if_cache, # LlamaIndex-compatible completion function
embedding_func=EmbeddingFunc( # LlamaIndex-compatible embedding function
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(texts, embed_model=embed_model)
),
)
```
#### For detailed documentation and examples, see:
- [LlamaIndex Documentation](lightrag/llm/Readme.md)
- [Direct OpenAI Example](examples/lightrag_llamaindex_direct_demo.py)
- [LiteLLM Proxy Example](examples/lightrag_llamaindex_litellm_demo.py)
</details>
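The snippet above passes `llama_index_complete_if_cache` straight in as the completion function; a small async wrapper, as used by the direct demo added in this commit (`examples/lightrag_llamaindex_direct_demo.py`), makes the LLM instance explicit. This is a hedged sketch; the model name and API key are placeholders:

```python
from llama_index.llms.openai import OpenAI
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache

async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    # Create the LlamaIndex LLM unless the caller already supplied one
    if "llm_instance" not in kwargs:
        kwargs["llm_instance"] = OpenAI(model="gpt-4o-mini", api_key="your-api-key")
    return await llama_index_complete_if_cache(
        kwargs["llm_instance"],
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
```

`llm_model_func=llm_model_func` then replaces the bare `llama_index_complete_if_cache` reference in the `LightRAG(...)` call above.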
<details>
<summary> <b>Conversation History Support</b> </summary>

View File

@@ -1,5 +1,3 @@
-version: '3.8'
services:
  lightrag:
    build: .

View File

@@ -87,18 +87,27 @@ custom_kg = {
        {
            "content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.",
            "source_id": "Source1",
+            "source_chunk_index": 0,
+        },
+        {
+            "content": "One outstanding feature of ProductX is its advanced AI capabilities.",
+            "source_id": "Source1",
+            "chunk_order_index": 1,
        },
        {
            "content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.",
            "source_id": "Source2",
+            "source_chunk_index": 0,
        },
        {
            "content": "EventY, held in CityC, attracts technology enthusiasts and companies from around the globe.",
            "source_id": "Source3",
+            "source_chunk_index": 0,
        },
        {
            "content": "None",
            "source_id": "UNKNOWN",
+            "source_chunk_index": 0,
        },
    ],
}
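Once assembled, the dictionary is handed to the custom-KG insertion entry point. A one-line sketch, assuming an initialized `rag` instance and LightRAG's `insert_custom_kg` method:

```python
rag.insert_custom_kg(custom_kg)
```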

View File

@@ -98,7 +98,6 @@ async def init():
    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,

View File

@@ -0,0 +1,113 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
llama_index_complete_if_cache,
llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
import asyncio
# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")
# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# OpenAI configuration
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key-here")
if not os.path.exists(WORKING_DIR):
print(f"Creating working directory: {WORKING_DIR}")
os.mkdir(WORKING_DIR)
# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize OpenAI if not in kwargs
if "llm_instance" not in kwargs:
llm_instance = OpenAI(
model=LLM_MODEL,
api_key=OPENAI_API_KEY,
temperature=0.7,
)
kwargs["llm_instance"] = llm_instance
response = await llama_index_complete_if_cache(
kwargs["llm_instance"],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
print(f"LLM request failed: {str(e)}")
raise
# Initialize embedding function
async def embedding_func(texts):
try:
embed_model = OpenAIEmbedding(
model=EMBEDDING_MODEL,
api_key=OPENAI_API_KEY,
)
return await llama_index_embed(texts, embed_model=embed_model)
except Exception as e:
print(f"Embedding failed: {str(e)}")
raise
# Get embedding dimension
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"embedding_dim={embedding_dim}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Test different query modes
print("\nNaive Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
print("\nLocal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
print("\nGlobal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
print("\nHybrid Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -0,0 +1,116 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
llama_index_complete_if_cache,
llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
import asyncio
# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")
# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# LiteLLM configuration
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
print(f"LITELLM_URL: {LITELLM_URL}")
LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-1234")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize LiteLLM if not in kwargs
if "llm_instance" not in kwargs:
llm_instance = LiteLLM(
model=f"openai/{LLM_MODEL}", # Format: "provider/model_name"
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
temperature=0.7,
)
kwargs["llm_instance"] = llm_instance
response = await llama_index_complete_if_cache(
kwargs["llm_instance"],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
print(f"LLM request failed: {str(e)}")
raise
# Initialize embedding function
async def embedding_func(texts):
try:
embed_model = LiteLLMEmbedding(
model_name=f"openai/{EMBEDDING_MODEL}",
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
)
return await llama_index_embed(texts, embed_model=embed_model)
except Exception as e:
print(f"Embedding failed: {str(e)}")
raise
# Get embedding dimension
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"embedding_dim={embedding_dim}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Test different query modes
print("\nNaive Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
print("\nLocal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
print("\nGlobal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
print("\nHybrid Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,9 +1,8 @@
-import os
import inspect
+import os
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embed
-from lightrag.utils import EmbeddingFunc
-from lightrag.lightrag import always_get_an_event_loop
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
from lightrag import QueryParam

# WorkingDir

View File

@@ -63,7 +63,6 @@ async def main():
    # Initialize LightRAG
    # We use TiDB DB as the KV/vector
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,

View File

@@ -70,7 +70,7 @@ def main():
        ),
        vector_storage="FaissVectorDBStorage",
        vector_db_storage_cls_kwargs={
-            "cosine_better_than_threshold": 0.3  # Your desired threshold
+            "cosine_better_than_threshold": 0.2  # Your desired threshold
        },
    )
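The same threshold can also be driven by the configuration added to `LightRAG` itself in this commit (see the `cosine_better_than_threshold` field and `COSINE_THRESHOLD` variable in the `lightrag.py` diff below). A hedged sketch; the working directory and the omitted `embedding_func`/`llm_model_func` arguments are placeholders:

```python
# Per-instance override; __post_init__ merges this value into
# vector_db_storage_cls_kwargs, so the explicit dict above becomes optional.
# (The field's default is read from the COSINE_THRESHOLD environment variable.)
rag = LightRAG(
    working_dir="./your_working_dir",        # placeholder
    vector_storage="FaissVectorDBStorage",
    cosine_better_than_threshold=0.2,
    # embedding_func=..., llm_model_func=... as in the rest of this example
)
```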

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.1.7"
+__version__ = "1.1.11"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -1 +1,157 @@
-# print ("init package vars here. ......")
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
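A short sketch of how this table is meant to be consumed, mirroring the `check_storage_env_vars` helper visible in the `lightrag.py` diff further down:

```python
import os

def check_storage_env_vars(storage_name: str) -> None:
    """Raise if environment variables required by a storage backend are missing."""
    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Storage implementation '{storage_name}' requires the following "
            f"environment variables: {', '.join(missing_vars)}"
        )
```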
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
def verify_storage_implementation(storage_type: str, storage_name: str) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
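Taken together, `STORAGE_IMPLEMENTATIONS`, `STORAGE_ENV_REQUIREMENTS`, and `STORAGES` drive how a backend is validated and loaded. A condensed sketch of that flow (the helper name and example arguments are illustrative; `LightRAG` itself resolves modules through the `lazy_external_import` helper shown in the `lightrag.py` diff below):

```python
import importlib

def load_storage_class(storage_type: str, storage_name: str):
    # Reject implementations that are not registered for this storage type
    verify_storage_implementation(storage_type, storage_name)
    # STORAGES maps the implementation name to a module path relative to the
    # lightrag package, e.g. "JsonKVStorage" -> ".kg.json_kv_impl"
    module = importlib.import_module(STORAGES[storage_name], package="lightrag")
    return getattr(module, storage_name)

kv_cls = load_storage_class("KV_STORAGE", "JsonKVStorage")
```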

View File

@@ -34,14 +34,9 @@ if not pm.is_installed("psycopg-pool"):
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
try: import psycopg
import psycopg from psycopg.rows import namedtuple_row
from psycopg.rows import namedtuple_row from psycopg_pool import AsyncConnectionPool, PoolTimeout
from psycopg_pool import AsyncConnectionPool, PoolTimeout
except ImportError:
raise ImportError(
"`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
)
class AGEQueryException(Exception): class AGEQueryException(Exception):

View File

@@ -10,13 +10,8 @@ import pipmaster as pm
if not pm.is_installed("chromadb"): if not pm.is_installed("chromadb"):
pm.install("chromadb") pm.install("chromadb")
try: from chromadb import HttpClient, PersistentClient
from chromadb import HttpClient, PersistentClient from chromadb.config import Settings
from chromadb.config import Settings
except ImportError as e:
raise ImportError(
"`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
) from e
@final @final
@@ -113,9 +108,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
logger.warning("Empty data provided to vector DB") return
return []
try: try:
ids = list(data.keys()) ids = list(data.keys())

View File

@@ -20,12 +20,7 @@ from lightrag.base import (
if not pm.is_installed("faiss"): if not pm.is_installed("faiss"):
pm.install("faiss") pm.install("faiss")
try: import faiss
import faiss
except ImportError as e:
raise ImportError(
"`faiss` library is not installed. Please install it via pip: `pip install faiss`."
) from e
@final @final
@@ -84,10 +79,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
... ...
} }
""" """
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
logger.warning("You are inserting empty data to the vector DB") return
return []
current_time = time.time() current_time = time.time()

View File

@@ -2,6 +2,7 @@ import asyncio
import inspect
import json
import os
+import pipmaster as pm
from dataclasses import dataclass
from typing import Any, Dict, List, final
@@ -20,14 +21,12 @@ from lightrag.utils import logger
from ..base import BaseGraphStorage

-try:
-    from gremlin_python.driver import client, serializer
-    from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-    from gremlin_python.driver.protocol import GremlinServerError
-except ImportError as e:
-    raise ImportError(
-        "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
-    ) from e
+if not pm.is_installed("gremlinpython"):
+    pm.install("gremlinpython")
+
+from gremlin_python.driver import client, serializer
+from gremlin_python.driver.aiohttp.transport import AiohttpTransport
+from gremlin_python.driver.protocol import GremlinServerError


@final

View File

@@ -67,6 +67,10 @@ class JsonDocStatusStorage(DocStatusStorage):
        write_json(self._data, self._file_name)

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
        self._data.update(data)
        await self.index_done_callback()

View File

@@ -43,6 +43,9 @@ class JsonKVStorage(BaseKVStorage):
        return set(keys) - set(self._data.keys())

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)

View File

@@ -14,13 +14,8 @@ if not pm.is_installed("configparser"):
if not pm.is_installed("pymilvus"): if not pm.is_installed("pymilvus"):
pm.install("pymilvus") pm.install("pymilvus")
try: import configparser
import configparser from pymilvus import MilvusClient
from pymilvus import MilvusClient
except ImportError as e:
raise ImportError(
"`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
) from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -80,11 +75,11 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not len(data): if not data:
logger.warning("You insert an empty data to vector DB") return
return []
list_data = [ list_data: list[dict[str, Any]] = [
{ {
"id": k, "id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},

View File

@@ -25,18 +25,13 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"): if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
try: from motor.motor_asyncio import (
from motor.motor_asyncio import ( AsyncIOMotorClient,
AsyncIOMotorClient, AsyncIOMotorDatabase,
AsyncIOMotorDatabase, AsyncIOMotorCollection,
AsyncIOMotorCollection, )
) from pymongo.operations import SearchIndexModel
from pymongo.operations import SearchIndexModel from pymongo.errors import PyMongoError
from pymongo.errors import PyMongoError
except ImportError as e:
raise ImportError(
"`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
) from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -113,8 +108,12 @@ class MongoKVStorage(BaseKVStorage):
return keys - existing_ids return keys - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks = [] update_tasks: list[Any] = []
for mode, items in data.items(): for mode, items in data.items():
for k, v in items.items(): for k, v in items.items():
key = f"{mode}_{k}" key = f"{mode}_{k}"
@@ -186,7 +185,10 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
update_tasks = [] logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []
for k, v in data.items(): for k, v in data.items():
data[k]["_id"] = k data[k]["_id"] = k
update_tasks.append( update_tasks.append(
@@ -860,10 +862,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist") logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
logger.warning("You are inserting an empty data set to vector DB") return
return []
list_data = [ list_data = [
{ {

View File

@@ -18,12 +18,7 @@ from lightrag.base import (
if not pm.is_installed("nano-vectordb"): if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb") pm.install("nano-vectordb")
try: from nano_vectordb import NanoVectorDB
from nano_vectordb import NanoVectorDB
except ImportError as e:
raise ImportError(
"`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
) from e
@final @final
@@ -50,10 +45,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not len(data): if not data:
logger.warning("You insert an empty data to vector DB") return
return []
current_time = time.time() current_time = time.time()
list_data = [ list_data = [

View File

@@ -23,18 +23,13 @@ import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
try: from neo4j import (
from neo4j import ( AsyncGraphDatabase,
AsyncGraphDatabase, exceptions as neo4jExceptions,
exceptions as neo4jExceptions, AsyncDriver,
AsyncDriver, AsyncManagedTransaction,
AsyncManagedTransaction, GraphDatabase,
GraphDatabase, )
)
except ImportError as e:
raise ImportError(
"`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
) from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")

View File

@@ -17,16 +17,12 @@ import pipmaster as pm
if not pm.is_installed("networkx"): if not pm.is_installed("networkx"):
pm.install("networkx") pm.install("networkx")
if not pm.is_installed("graspologic"): if not pm.is_installed("graspologic"):
pm.install("graspologic") pm.install("graspologic")
try: import networkx as nx
from graspologic import embed from graspologic import embed
import networkx as nx
except ImportError as e:
raise ImportError(
"`networkx` library is not installed. Please install it via pip: `pip install networkx`."
) from e
@final @final

View File

@@ -26,14 +26,8 @@ if not pm.is_installed("graspologic"):
if not pm.is_installed("oracledb"): if not pm.is_installed("oracledb"):
pm.install("oracledb") pm.install("oracledb")
try: from graspologic import embed
from graspologic import embed import oracledb
import oracledb
except ImportError as e:
raise ImportError(
"`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
) from e
class OracleDB: class OracleDB:
@@ -51,7 +45,7 @@ class OracleDB:
self.increment = 1 self.increment = 1
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
if self.user is None or self.password is None: if self.user is None or self.password is None:
raise ValueError("Missing database user or password in addon_params") raise ValueError("Missing database user or password")
try: try:
oracledb.defaults.fetch_lobs = False oracledb.defaults.fetch_lobs = False
@@ -332,6 +326,10 @@ class OracleKVStorage(BaseKVStorage):
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [ list_data = [
{ {

View File

@@ -38,14 +38,8 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
try: import asyncpg
import asyncpg from asyncpg import Pool
from asyncpg import Pool
except ImportError as e:
raise ImportError(
"`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
) from e
class PostgreSQLDB: class PostgreSQLDB:
@@ -61,9 +55,7 @@ class PostgreSQLDB:
self.pool: Pool | None = None self.pool: Pool | None = None
if self.user is None or self.password is None or self.database is None: if self.user is None or self.password is None or self.database is None:
raise ValueError( raise ValueError("Missing database user, password, or database")
"Missing database user, password, or database in addon_params"
)
async def initdb(self): async def initdb(self):
try: try:
@@ -353,6 +345,10 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass pass
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -454,10 +450,10 @@ class PGVectorStorage(BaseVectorStorage):
return upsert_sql, data return upsert_sql, data
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not len(data): if not data:
logger.warning("You insert an empty data to vector DB") return
return []
current_time = time.time() current_time = time.time()
list_data = [ list_data = [
{ {
@@ -618,6 +614,10 @@ class PGDocStatusStorage(DocStatusStorage):
Args: Args:
data: dictionary of document IDs and their status data data: dictionary of document IDs and their status data
""" """
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status) sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7) values($1,$2,$3,$4,$5,$6,$7)
on conflict(id,workspace) do update set on conflict(id,workspace) do update set

View File

@@ -15,16 +15,10 @@ config.read("config.ini", "utf-8")
import pipmaster as pm

-if not pm.is_installed("qdrant_client"):
-    pm.install("qdrant_client")
-
-try:
-    from qdrant_client import QdrantClient, models
-except ImportError:
-    raise ImportError(
-        "`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
-    )
+if not pm.is_installed("qdrant-client"):
+    pm.install("qdrant-client")
+
+from qdrant_client import QdrantClient, models


def compute_mdhash_id_for_qdrant(
@@ -93,9 +87,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
        )

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
        list_data = [
            {
                "id": k,

View File

@@ -49,6 +49,9 @@ class RedisKVStorage(BaseKVStorage):
        return set(keys) - existing_ids

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
        pipe = self._redis.pipeline()
        for k, v in data.items():

View File

@@ -20,13 +20,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"): if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy") pm.install("sqlalchemy")
try: from sqlalchemy import create_engine, text
from sqlalchemy import create_engine, text
except ImportError as e:
raise ImportError(
"`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
) from e
class TiDB: class TiDB:
@@ -217,6 +211,9 @@ class TiDBKVStorage(BaseKVStorage):
################ INSERT full_doc AND chunks ################ ################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data) self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -324,12 +321,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):
###### INSERT entities And relationships ###### ###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already logger.info(f"Inserting {len(data)} to {self.namespace}")
if not len(data): if not data:
logger.warning("You insert an empty data to vector DB") return
return []
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
return [] return
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
list_data = [ list_data = [

View File

@@ -6,7 +6,13 @@ import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final
+
+from lightrag.kg import (
+    STORAGE_ENV_REQUIREMENTS,
+    STORAGES,
+    verify_storage_implementation,
+)

from .base import (
    BaseGraphStorage,
@@ -32,221 +38,37 @@ from .operate import (
from .prompt import GRAPH_FIELD_SEP
from .utils import (
    EmbeddingFunc,
    always_get_an_event_loop,
    compute_mdhash_id,
    convert_response_to_json,
    lazy_external_import,
    limit_async_func_call,
    logger,
    set_logger,
    encode_string_by_tiktoken,
)
from .types import KnowledgeGraph

# TODO: TO REMOVE @Yannick
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
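This helper (moved to `lightrag.utils`, as the updated example import earlier in this diff shows) is what lets the synchronous `insert`/`query` wrappers drive their async counterparts. A hedged sketch of that pattern:

```python
loop = always_get_an_event_loop()
# Run an async LightRAG call from synchronous code, e.g. inside the sync insert() wrapper
loop.run_until_complete(rag.ainsert("some text"))
```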
+@final
@dataclass
class LightRAG:
    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

+    # Directory
+    # ---
+
    working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    """Directory where cache and temporary files are stored."""

-    embedding_cache_config: dict[str, Any] = field(
-        default_factory=lambda: {
-            "enabled": False,
-            "similarity_threshold": 0.95,
-            "use_llm_check": False,
-        }
-    )
-    """Configuration for embedding cache.
-    - enabled: If True, enables caching to avoid redundant computations.
-    - similarity_threshold: Minimum similarity score to use cached embeddings.
-    - use_llm_check: If True, validates cached embeddings using an LLM.
-    """
+    # Storage
+    # ---

    kv_storage: str = field(default="JsonKVStorage")
    """Storage backend for key-value data."""
@@ -261,32 +83,74 @@ class LightRAG:
"""Storage type for tracking document processing statuses.""" """Storage type for tracking document processing statuses."""
# Logging # Logging
current_log_level = logger.level # ---
log_level: int = field(default=current_log_level)
log_level: int = field(default=logger.level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_dir: str = field(default=os.getcwd()) log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
"""Directory where logs are stored. Defaults to the current working directory.""" """Log file path."""
# Text chunking
chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
"""Model name used for tokenization when chunking text."""
    # Entity extraction
-    entity_extract_max_gleaning: int = 1
+    # ---
+
+    entity_extract_max_gleaning: int = field(default=1)
    """Maximum number of entity extraction attempts for ambiguous content."""

-    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
+    entity_summary_to_max_tokens: int = field(
+        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    )
# Text chunking
# ---
chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = field(
default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
"""Maximum number of tokens used for summarizing extracted entities.""" """Maximum number of tokens used for summarizing extracted entities."""
chunking_func: Callable[
[
str,
str | None,
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function should take the following parameters:
- `content`: The text to be split into chunks.
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
- `split_by_character_only`: If True, the text is split only on the specified character.
- `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
The function should return a list of dictionaries, where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
- `content`: The text content of the chunk.
Defaults to `chunking_by_token_size` if not specified.
"""
    # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+
+    node_embedding_algorithm: str = field(default="node2vec")
    """Algorithm used for node embedding in knowledge graphs."""

    node2vec_params: dict[str, int] = field(
@@ -308,116 +172,102 @@ class LightRAG:
    - random_seed: Seed value for reproducibility.
    """

-    embedding_func: EmbeddingFunc | None = None
+    # Embedding
+    # ---
+
+    embedding_func: EmbeddingFunc | None = field(default=None)
    """Function for computing text embeddings. Must be set before use."""

-    embedding_batch_num: int = 32
+    embedding_batch_num: int = field(default=32)
    """Batch size for embedding computations."""

-    embedding_func_max_async: int = 16
+    embedding_func_max_async: int = field(default=16)
    """Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
    # LLM Configuration
-    llm_model_func: Callable[..., object] | None = None
+    # ---
+
+    llm_model_func: Callable[..., object] | None = field(default=None)
    """Function for interacting with the large language model (LLM). Must be set before use."""

-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    llm_model_name: str = field(default="gpt-4o-mini")
    """Name of the LLM model used for generating responses."""

-    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
+    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
    """Maximum number of tokens allowed per LLM response."""

-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
    """Maximum number of concurrent LLM calls."""

    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Storage
+    # ---
+
    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    namespace_prefix: str = field(default="")
    """Prefix for namespacing stored data across different environments."""

-    enable_llm_cache: bool = True
+    enable_llm_cache: bool = field(default=True)
    """Enables caching for LLM responses to avoid redundant computations."""

-    enable_llm_cache_for_entity_extract: bool = True
+    enable_llm_cache_for_entity_extract: bool = field(default=True)
    """If True, enables caching for entity extraction steps to reduce LLM costs."""

    # Extensions
+    # ---
+
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
+    """Maximum number of parallel insert operations."""
+
    addon_params: dict[str, Any] = field(default_factory=dict)

    # Storages Management
-    auto_manage_storages_states: bool = True
+    # ---
+
+    auto_manage_storages_states: bool = field(default=True)
    """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""

-    """Dictionary for additional parameters and extensions."""
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
-        convert_response_to_json
-    )
+    # Storages Management
+    # ---
+
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
+        default_factory=lambda: convert_response_to_json
+    )
+    """
+    Custom function for converting LLM responses to JSON format.
+    The default function is :func:`.utils.convert_response_to_json`.
+    """
+
+    cosine_better_than_threshold: float = field(
+        default=float(os.getenv("COSINE_THRESHOLD", 0.2))
+    )
+
+    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

-    # Custom Chunking Function
chunking_func: Callable[
[
str,
str | None,
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = chunking_by_token_size
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
    def __post_init__(self):
-        os.makedirs(self.log_dir, exist_ok=True)
-        log_file = os.path.join(self.log_dir, "lightrag.log")
-        set_logger(log_file)
        logger.setLevel(self.log_level)
+        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
+        set_logger(self.log_file_path)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)
@@ -432,22 +282,16 @@ class LightRAG:
        for storage_type, storage_name in storage_configs:
            # Verify storage implementation compatibility
            verify_storage_implementation(storage_type, storage_name)
            # Check environment variables
            # self.check_storage_env_vars(storage_name)

        # Ensure vector_db_storage_cls_kwargs has required fields
        self.vector_db_storage_cls_kwargs = {
            "cosine_better_than_threshold": self.cosine_better_than_threshold,
            **self.vector_db_storage_cls_kwargs,
        }

        # Show config
        global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -555,7 +399,7 @@ class LightRAG:
            )
        )

        self._storages_status = StoragesStatus.CREATED

        # Initialize storages
        if self.auto_manage_storages_states:
@@ -570,7 +414,7 @@ class LightRAG:
    async def initialize_storages(self):
        """Asynchronously initialize the storages"""
        if self._storages_status == StoragesStatus.CREATED:
            tasks = []

            for storage in (
@@ -588,12 +432,12 @@ class LightRAG:
            await asyncio.gather(*tasks)

            self._storages_status = StoragesStatus.INITIALIZED
            logger.debug("Initialized Storages")

    async def finalize_storages(self):
        """Asynchronously finalize the storages"""
        if self._storages_status == StoragesStatus.INITIALIZED:
            tasks = []

            for storage in (
@@ -611,7 +455,7 @@ class LightRAG:
            await asyncio.gather(*tasks)

            self._storages_status = StoragesStatus.FINALIZED
            logger.debug("Finalized Storages")
    async def get_graph_labels(self):
@@ -687,7 +531,7 @@ class LightRAG:
            return
        update_storage = True

        logger.info(f"Inserting {len(new_docs)} docs")

        inserting_chunks: dict[str, Any] = {}
        for chunk_text in text_chunks:
@@ -780,108 +624,122 @@ class LightRAG:
        4. Update the document status
        """
        # 1. Get all pending, failed, and abnormally terminated processing documents.
        # Run the asynchronous status retrievals in parallel using asyncio.gather
        processing_docs, failed_docs, pending_docs = await asyncio.gather(
            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
            self.doc_status.get_docs_by_status(DocStatus.FAILED),
            self.doc_status.get_docs_by_status(DocStatus.PENDING),
        )

        to_process_docs: dict[str, DocProcessingStatus] = {}
        to_process_docs.update(processing_docs)
        to_process_docs.update(failed_docs)
        to_process_docs.update(pending_docs)

        if not to_process_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        # 2. split docs into chunks, insert chunks, update doc status
        docs_batches = [
            list(to_process_docs.items())[i : i + self.max_parallel_insert]
            for i in range(0, len(to_process_docs), self.max_parallel_insert)
        ]

        logger.info(f"Number of batches to process: {len(docs_batches)}.")

        batches: list[Any] = []
        # 3. iterate over batches
        for batch_idx, docs_batch in enumerate(docs_batches):

            async def batch(
                batch_idx: int,
                docs_batch: list[tuple[str, DocProcessingStatus]],
                size_batch: int,
            ) -> None:
                logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
                # 4. iterate over batch
                for doc_id_processing_status in docs_batch:
                    doc_id, status_doc = doc_id_processing_status
                    # Update status in processing
                    doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                    # Generate chunks from document
                    chunks: dict[str, Any] = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in self.chunking_func(
                            status_doc.content,
                            split_by_character,
                            split_by_character_only,
                            self.chunk_overlap_token_size,
                            self.chunk_token_size,
                            self.tiktoken_model_name,
                        )
                    }
                    # Process document (text chunks and full docs) in parallel
                    tasks = [
                        self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSING,
                                    "updated_at": datetime.now().isoformat(),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                }
                            }
                        ),
                        self.chunks_vdb.upsert(chunks),
                        self._process_entity_relation_graph(chunks),
                        self.full_docs.upsert(
                            {doc_id: {"content": status_doc.content}}
                        ),
                        self.text_chunks.upsert(chunks),
                    ]
                    try:
                        await asyncio.gather(*tasks)
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSED,
                                    "chunks_count": len(chunks),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                    except Exception as e:
                        logger.error(f"Failed to process document {doc_id}: {str(e)}")
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.FAILED,
                                    "error": str(e),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                        continue
                logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

            batches.append(batch(batch_idx, docs_batch, len(docs_batches)))

        await asyncio.gather(*batches)
        await self._insert_done()
    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
        try:
            await extract_entities(
                chunk,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
@@ -889,12 +747,6 @@ class LightRAG:
                llm_response_cache=self.llm_response_cache,
                global_config=asdict(self),
            )
        except Exception as e:
            logger.error("Failed to extract entities and relationships")
            raise e
@@ -914,6 +766,7 @@ class LightRAG:
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)
logger.info("All Insert done")
    def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
        loop = always_get_an_event_loop()
@@ -926,11 +779,28 @@ class LightRAG:
        all_chunks_data: dict[str, dict[str, str]] = {}
        chunk_to_source_map: dict[str, str] = {}
        for chunk_data in custom_kg.get("chunks", {}):
            chunk_content = chunk_data["content"].strip()
            source_id = chunk_data["source_id"]
            tokens = len(
                encode_string_by_tiktoken(
                    chunk_content, model_name=self.tiktoken_model_name
                )
            )
            chunk_order_index = (
                0
                if "chunk_order_index" not in chunk_data.keys()
                else chunk_data["chunk_order_index"]
            )
            chunk_id = compute_mdhash_id(chunk_content, prefix="chunk-")

            chunk_entry = {
                "content": chunk_content,
                "source_id": source_id,
                "tokens": tokens,
                "chunk_order_index": chunk_order_index,
                "full_doc_id": source_id,
                "status": DocStatus.PROCESSED,
            }
            all_chunks_data[chunk_id] = chunk_entry
            chunk_to_source_map[source_id] = chunk_id
            update_storage = True
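
For reference, a hypothetical `custom_kg` payload whose `chunks` entries match the fields read above (`content`, `source_id`, and the optional `chunk_order_index`); the other top-level keys such as entities and relationships are omitted here.

```python
custom_kg = {
    "chunks": [
        {
            "content": "Alice founded Acme Corp in 2020.",
            "source_id": "doc-alice-acme",
            "chunk_order_index": 0,  # optional; treated as 0 when missing
        },
    ],
}

# rag.insert_custom_kg(custom_kg)  # requires a configured LightRAG instance
```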
@@ -1177,7 +1047,6 @@ class LightRAG:
        # ---------------------
        # STEP 1: Keyword Extraction
        # ---------------------
        hl_keywords, ll_keywords = await extract_keywords_only(
            text=query,
            param=param,
@@ -1603,3 +1472,21 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None result["vector_data"] = vector_data[0] if vector_data else None
return result return result
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)

161
lightrag/llm/Readme.md Normal file

@@ -0,0 +1,161 @@
1. **LlamaIndex** (`llm/llama_index_impl.py`):
- Provides integration with OpenAI and other providers through LlamaIndex
- Supports both direct API access and proxy services like LiteLLM
- Handles embeddings and completions with consistent interfaces
- See example implementations:
- [Direct OpenAI Usage](../../examples/lightrag_llamaindex_direct_demo.py)
- [LiteLLM Proxy Usage](../../examples/lightrag_llamaindex_litellm_demo.py)
<details>
<summary> <b>Using LlamaIndex</b> </summary>
LightRAG supports LlamaIndex for embeddings and completions in two ways: direct OpenAI usage or through LiteLLM proxy.
### Setup
First, install the required dependencies:
```bash
pip install llama-index-llms-litellm llama-index-embeddings-litellm
```
### Standard OpenAI Usage
```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from lightrag.utils import EmbeddingFunc, logger
# Initialize with direct OpenAI access
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize OpenAI if not in kwargs
if 'llm_instance' not in kwargs:
llm_instance = OpenAI(
model="gpt-4",
api_key="your-openai-key",
temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
response = await llama_index_complete_if_cache(
kwargs['llm_instance'],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
logger.error(f"LLM request failed: {str(e)}")
raise
# Initialize LightRAG with OpenAI
rag = LightRAG(
working_dir="your/path",
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(
texts,
embed_model=OpenAIEmbedding(
model="text-embedding-3-large",
api_key="your-openai-key"
)
),
),
)
```
### Using LiteLLM Proxy
1. Use any LLM provider through LiteLLM
2. Leverage LlamaIndex's embedding and completion capabilities
3. Maintain consistent configuration across services
```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
from lightrag.utils import EmbeddingFunc, logger
# Initialize with LiteLLM proxy
# ("settings" below stands for your own configuration object exposing
#  LLM_MODEL, EMBEDDING_MODEL, LITELLM_URL and LITELLM_KEY; see Environment Variables)
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize LiteLLM if not in kwargs
if 'llm_instance' not in kwargs:
llm_instance = LiteLLM(
model=f"openai/{settings.LLM_MODEL}", # Format: "provider/model_name"
api_base=settings.LITELLM_URL,
api_key=settings.LITELLM_KEY,
temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
response = await llama_index_complete_if_cache(
kwargs['llm_instance'],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
logger.error(f"LLM request failed: {str(e)}")
raise
# Initialize LightRAG with LiteLLM
rag = LightRAG(
working_dir="your/path",
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(
texts,
embed_model=LiteLLMEmbedding(
model_name=f"openai/{settings.EMBEDDING_MODEL}",
api_base=settings.LITELLM_URL,
api_key=settings.LITELLM_KEY,
)
),
),
)
```
### Environment Variables
For OpenAI direct usage:
```bash
OPENAI_API_KEY=your-openai-key
```
For LiteLLM proxy:
```bash
# LiteLLM Configuration
LITELLM_URL=http://litellm:4000
LITELLM_KEY=your-litellm-key
# Model Configuration
LLM_MODEL=gpt-4
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_MAX_TOKEN_SIZE=8192
```
### Key Differences
1. **Direct OpenAI**:
- Simpler setup
- Direct API access
- Requires OpenAI API key
2. **LiteLLM Proxy**:
- Model provider agnostic
- Centralized API key management
- Support for multiple providers
- Better cost control and monitoring
</details>

lightrag/llm/llama_index_impl.py Normal file

@@ -0,0 +1,208 @@
import pipmaster as pm
from typing import List, Optional

from lightrag.utils import logger

# Install required dependencies before importing from llama_index
if not pm.is_installed("llama-index"):
    pm.install("llama-index")

from llama_index.core.llms import (
    ChatMessage,
    MessageRole,
    ChatResponse,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import numpy as np
def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
"""
Configure LlamaIndex settings.
Args:
settings: LlamaIndex Settings instance. If None, uses default settings.
**kwargs: Additional settings to override/configure
"""
if settings is None:
settings = LlamaIndexSettings()
# Update settings with any provided kwargs
for key, value in kwargs.items():
if hasattr(settings, key):
setattr(settings, key, value)
else:
logger.warning(f"Unknown LlamaIndex setting: {key}")
# Set as global settings
LlamaIndexSettings.set_global(settings)
return settings
def format_chat_messages(messages):
"""Format chat messages into LlamaIndex format."""
formatted_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=content)
)
elif role == "assistant":
formatted_messages.append(
ChatMessage(role=MessageRole.ASSISTANT, content=content)
)
elif role == "user":
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
else:
logger.warning(f"Unknown role {role}, treating as user message")
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
return formatted_messages
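
A quick illustration of `format_chat_messages` on a mixed history, assuming the module is importable as `lightrag.llm.llama_index_impl`; an off-spec role falls back to a user message and logs a warning.

```python
from lightrag.llm.llama_index_impl import format_chat_messages

history = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Summarize LightRAG in one sentence."},
    {"role": "moderator", "content": "off-spec role"},  # unknown role -> treated as user
]

messages = format_chat_messages(history)
print([m.role for m in messages])  # [SYSTEM, USER, USER]
```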
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_complete_if_cache(
model: str,
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
**kwargs,
) -> str:
"""Complete the prompt using LlamaIndex."""
try:
# Format messages for chat
formatted_messages = []
# Add system message if provided
if system_prompt:
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
)
# Add history messages
for msg in history_messages:
formatted_messages.append(
ChatMessage(
role=MessageRole.USER
if msg["role"] == "user"
else MessageRole.ASSISTANT,
content=msg["content"],
)
)
# Add current prompt
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
# Get LLM instance from kwargs
if "llm_instance" not in kwargs:
raise ValueError("llm_instance must be provided in kwargs")
llm = kwargs["llm_instance"]
# Get response
response: ChatResponse = await llm.achat(messages=formatted_messages)
# In newer versions, the response is in message.content
content = response.message.content
return content
except Exception as e:
logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
raise
async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
) -> str:
"""
Main completion function for LlamaIndex
Args:
prompt: Input prompt
system_prompt: Optional system prompt
history_messages: Optional chat history
keyword_extraction: Whether to extract keywords from response
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await llama_index_complete_if_cache(
kwargs.get("llm_instance"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction:
return locate_json_string_body_from_string(result)
return result
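
A hedged usage sketch: `llama_index_complete` expects the configured LlamaIndex LLM under `llm_instance` in kwargs. The OpenAI model name and placeholder key below are only examples.

```python
import asyncio

from llama_index.llms.openai import OpenAI

from lightrag.llm.llama_index_impl import llama_index_complete


async def demo() -> None:
    llm = OpenAI(model="gpt-4o-mini", api_key="your-openai-key")
    answer = await llama_index_complete(
        "What does LightRAG use the knowledge graph for?",
        system_prompt="Answer in one sentence.",
        llm_instance=llm,
    )
    print(answer)


asyncio.run(demo())
```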
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_embed(
texts: list[str],
embed_model: BaseEmbedding = None,
settings: LlamaIndexSettings = None,
**kwargs,
) -> np.ndarray:
"""
Generate embeddings using LlamaIndex
Args:
texts: List of texts to embed
embed_model: LlamaIndex embedding model
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if settings:
configure_llama_index(settings)
if embed_model is None:
raise ValueError("embed_model must be provided")
# Use _get_text_embeddings for batch processing
embeddings = embed_model._get_text_embeddings(texts)
return np.array(embeddings)


@@ -329,7 +329,7 @@ async def extract_entities(
    relationships_vdb: BaseVectorStorage,
    global_config: dict[str, str],
    llm_response_cache: BaseKVStorage | None = None,
) -> None:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    enable_llm_cache_for_entity_extract: bool = global_config[
@@ -491,11 +491,9 @@ async def extract_entities(
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)

        logger.debug(
            f"Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
        )
        return dict(maybe_nodes), dict(maybe_edges)
@@ -524,16 +522,18 @@ async def extract_entities(
        ]
    )

    if not (all_entities_data or all_relationships_data):
        logger.info("Didn't extract any entities and relationships.")
        return

    if not all_entities_data:
        logger.info("Didn't extract any entities")
    if not all_relationships_data:
        logger.info("Didn't extract any relationships")

    logger.info(
        f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}"
    )

    if entity_vdb is not None:
        data_for_vdb = {
@@ -562,8 +562,6 @@ async def extract_entities(
        }
        await relationships_vdb.upsert(data_for_vdb)
async def kg_query(
    query: str,
@@ -1328,15 +1326,12 @@ async def _get_edge_data(
        ),
    )

    edge_datas = [
        {
            "src_id": k["src_id"],
            "tgt_id": k["tgt_id"],
            "rank": d,
            "created_at": k.get("__created_at__", None),
            **v,
        }
        for k, v, d in zip(results, edge_datas, edge_degree)
@@ -1345,16 +1340,11 @@ async def _get_edge_data(
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )

    use_entities, use_text_units = await asyncio.gather(
        _find_most_related_entities_from_relationships(
            edge_datas, query_param, knowledge_graph_inst


@@ -9,15 +9,14 @@ PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["process_tickers"] = ["", "", "", "", "", "", "", "", "", ""]
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"] PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
PROMPTS["entity_extraction"] = """-Goal- PROMPTS["entity_extraction"] = """---Goal---
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language. Use {language} as output language.
-Steps- ---Steps---
1. Identify all entities. For each identified entity, extract the following information: 1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name. - entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
- entity_type: One of the following types: [{entity_types}] - entity_type: One of the following types: [{entity_types}]
@@ -41,18 +40,17 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
5. When finished, output {completion_delimiter}

######################
---Examples---
######################
{examples}

#############################
---Real Data---
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:"""

PROMPTS["entity_extraction_examples"] = [
    """Example 1:
@@ -137,7 +135,7 @@ Make sure it is written in third person, and include the entity names so we the
Use {language} as output language.

#######
---Data---
Entities: {entity_name}
Description List: {description_list}
#######
@@ -205,12 +203,12 @@ Given the query and conversation history, list both high-level and low-level key
- "low_level_keywords" for specific entities or details - "low_level_keywords" for specific entities or details
###################### ######################
-Examples- ---Examples---
###################### ######################
{examples} {examples}
############################# #############################
-Real Data- ---Real Data---
###################### ######################
Conversation History: Conversation History:
{history} {history}


@@ -713,3 +713,47 @@ def get_conversation_turns(
    )

    return "\n".join(formatted_turns)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
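
This mirrors how the reproduce scripts below use it: grab a loop even when none is running, then drive an async helper synchronously. The coroutine here is a stand-in for something like `rag.aquery(...)`.

```python
from lightrag.utils import always_get_an_event_loop


async def answer_one(query: str) -> str:
    return f"processed: {query}"  # stand-in for an async LightRAG call


loop = always_get_an_event_loop()
print(loop.run_until_complete(answer_one("What is LightRAG?")))
```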
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
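
A small sketch of the deferred-import behaviour: the target module is only imported when the returned factory is called. The stdlib `collections.Counter` is used purely for illustration.

```python
from lightrag.utils import lazy_external_import

LazyCounter = lazy_external_import("collections", "Counter")

counts = LazyCounter("lightrag")  # importlib.import_module runs here, then Counter(...) is built
print(counts.most_common(2))
```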


@@ -1,7 +1,7 @@
import re
import json

from lightrag import LightRAG, QueryParam
from lightrag.utils import always_get_an_event_loop


def extract_queries(file_path):
@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
return None, {"query": query_text, "error": str(e)} return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json( def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file queries, rag_instance, query_param, output_file, error_file
): ):


@@ -1,10 +1,9 @@
import os
import re
import json

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop

import numpy as np
@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
return None, {"query": query_text, "error": str(e)} return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json( def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file queries, rag_instance, query_param, output_file, error_file
): ):