Merge remote-tracking branch 'origin/main' into refactor-api-server

This commit is contained in:
yangdx
2025-02-21 11:24:16 +08:00
40 changed files with 1393 additions and 592 deletions

61
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: Bug Report
description: File a bug report
title: "[Bug]: <title>"
labels: ["bug", "triage"]
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to file an issue?
description: Please help us manage our time by avoiding duplicates and common bugs with the steps below.
options:
- label: I have searched the existing issues and this bug is not already filed.
- label: I believe this is a legitimate bug, not just a question or feature request.
- type: textarea
id: description
attributes:
label: Describe the bug
description: A clear and concise description of what the bug is.
placeholder: What went wrong?
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: Steps to reproduce the behavior.
placeholder: How can we replicate the issue?
- type: textarea
id: expected_behavior
attributes:
label: Expected Behavior
description: A clear and concise description of what you expected to happen.
placeholder: What should have happened?
- type: textarea
id: configused
attributes:
label: LightRAG Config Used
description: The LightRAG configuration used for the run.
placeholder: The settings content or LightRAG configuration
value: |
# Paste your config here
- type: textarea
id: screenshotslogs
attributes:
label: Logs and screenshots
description: If applicable, add screenshots and logs to help explain your problem.
placeholder: Add logs and screenshots here
- type: textarea
id: additional_information
attributes:
label: Additional Information
description: |
- LightRAG Version: e.g., v0.1.1
- Operating System: e.g., Windows 10, Ubuntu 20.04
- Python Version: e.g., 3.8
- Related Issues: e.g., #1
- Any other relevant information.
value: |
- LightRAG Version:
- Operating System:
- Python Version:
- Related Issues:

1
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1 @@
blank_issues_enabled: false

View File

@@ -0,0 +1,26 @@
name: Feature Request
description: File a feature request
labels: ["enhancement"]
title: "[Feature Request]: <title>"
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to file a feature request?
description: Please help us manage our time by avoiding duplicates and common feature request with the steps below.
options:
- label: I have searched the existing feature request and this feature request is not already filed.
- label: I believe this is a legitimate feature request, not just a question or bug.
- type: textarea
id: feature_request_description
attributes:
label: Feature Request Description
description: A clear and concise description of the feature request you would like.
placeholder: What this feature request add more or improve?
- type: textarea
id: additional_context
attributes:
label: Additional Context
description: Add any other context or screenshots about the feature request here.
placeholder: Any additional information

26
.github/ISSUE_TEMPLATE/question.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: Question
description: Ask a general question
labels: ["question"]
title: "[Question]: <title>"
body:
- type: checkboxes
id: existingcheck
attributes:
label: Do you need to ask a question?
description: Please help us manage our time by avoiding duplicates and common questions with the steps below.
options:
- label: I have searched the existing question and discussions and this question is not already answered.
- label: I believe this is a legitimate question, not just a bug or feature request.
- type: textarea
id: question
attributes:
label: Your Question
description: A clear and concise description of your question.
placeholder: What is your question?
- type: textarea
id: context
attributes:
label: Additional Context
description: Provide any additional context or details that might help us understand your question better.
placeholder: Add any relevant information here

32
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,32 @@
<!--
Thanks for contributing to LightRAG!
Please ensure your pull request is ready for review before submitting.
About this template
This template helps contributors provide a clear and concise description of their changes. Feel free to adjust it as needed.
-->
## Description
[Briefly describe the changes made in this pull request.]
## Related Issues
[Reference any related issues or tasks addressed by this pull request.]
## Changes Made
[List the specific changes made in this pull request.]
## Checklist
- [ ] Changes tested locally
- [ ] Code reviewed
- [ ] Documentation updated (if necessary)
- [ ] Unit tests added (if applicable)
## Additional Notes
[Add any additional notes or context for the reviewer(s).]

View File

@@ -312,7 +312,41 @@ rag = LightRAG(
In order to run this experiment on low RAM GPU you should select small model and tune context window (increasing context increase memory consumption). For example, running this ollama example on repurposed mining GPU with 6Gb of RAM required to set context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`.
</details>
<details>
<summary> <b>LlamaIndex</b> </summary>
LightRAG supports integration with LlamaIndex.
1. **LlamaIndex** (`llm/llama_index_impl.py`):
- Integrates with OpenAI and other providers through LlamaIndex
- See [LlamaIndex Documentation](lightrag/llm/Readme.md) for detailed setup and examples
### Example Usage
```python
# Using LlamaIndex with direct OpenAI access
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
rag = LightRAG(
working_dir="your/path",
llm_model_func=llama_index_complete_if_cache, # LlamaIndex-compatible completion function
embedding_func=EmbeddingFunc( # LlamaIndex-compatible embedding function
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(texts, embed_model=embed_model)
),
)
```
#### For detailed documentation and examples, see:
- [LlamaIndex Documentation](lightrag/llm/Readme.md)
- [Direct OpenAI Example](examples/lightrag_llamaindex_direct_demo.py)
- [LiteLLM Proxy Example](examples/lightrag_llamaindex_litellm_demo.py)
</details>
<details>
<summary> <b>Conversation History Support</b> </summary>

View File

@@ -1,5 +1,3 @@
version: '3.8'
services:
lightrag:
build: .

View File

@@ -87,18 +87,27 @@ custom_kg = {
{
"content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.",
"source_id": "Source1",
"source_chunk_index": 0,
},
{
"content": "One outstanding feature of ProductX is its advanced AI capabilities.",
"source_id": "Source1",
"chunk_order_index": 1,
},
{
"content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.",
"source_id": "Source2",
"source_chunk_index": 0,
},
{
"content": "EventY, held in CityC, attracts technology enthusiasts and companies from around the globe.",
"source_id": "Source3",
"source_chunk_index": 0,
},
{
"content": "None",
"source_id": "UNKNOWN",
"source_chunk_index": 0,
},
],
}

View File

@@ -98,7 +98,6 @@ async def init():
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,

View File

@@ -0,0 +1,113 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
llama_index_complete_if_cache,
llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
import asyncio
# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")
# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# OpenAI configuration
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key-here")
if not os.path.exists(WORKING_DIR):
print(f"Creating working directory: {WORKING_DIR}")
os.mkdir(WORKING_DIR)
# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize OpenAI if not in kwargs
if "llm_instance" not in kwargs:
llm_instance = OpenAI(
model=LLM_MODEL,
api_key=OPENAI_API_KEY,
temperature=0.7,
)
kwargs["llm_instance"] = llm_instance
response = await llama_index_complete_if_cache(
kwargs["llm_instance"],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
print(f"LLM request failed: {str(e)}")
raise
# Initialize embedding function
async def embedding_func(texts):
try:
embed_model = OpenAIEmbedding(
model=EMBEDDING_MODEL,
api_key=OPENAI_API_KEY,
)
return await llama_index_embed(texts, embed_model=embed_model)
except Exception as e:
print(f"Embedding failed: {str(e)}")
raise
# Get embedding dimension
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"embedding_dim={embedding_dim}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Test different query modes
print("\nNaive Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
print("\nLocal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
print("\nGlobal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
print("\nHybrid Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -0,0 +1,116 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
llama_index_complete_if_cache,
llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
import asyncio
# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")
# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# LiteLLM configuration
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
print(f"LITELLM_URL: {LITELLM_URL}")
LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-1234")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize LiteLLM if not in kwargs
if "llm_instance" not in kwargs:
llm_instance = LiteLLM(
model=f"openai/{LLM_MODEL}", # Format: "provider/model_name"
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
temperature=0.7,
)
kwargs["llm_instance"] = llm_instance
response = await llama_index_complete_if_cache(
kwargs["llm_instance"],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
print(f"LLM request failed: {str(e)}")
raise
# Initialize embedding function
async def embedding_func(texts):
try:
embed_model = LiteLLMEmbedding(
model_name=f"openai/{EMBEDDING_MODEL}",
api_base=LITELLM_URL,
api_key=LITELLM_KEY,
)
return await llama_index_embed(texts, embed_model=embed_model)
except Exception as e:
print(f"Embedding failed: {str(e)}")
raise
# Get embedding dimension
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"embedding_dim={embedding_dim}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Test different query modes
print("\nNaive Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
print("\nLocal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
print("\nGlobal Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
print("\nHybrid Search:")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,9 +1,8 @@
import os
import inspect
import os
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
from lightrag import QueryParam
# WorkingDir

View File

@@ -63,7 +63,6 @@ async def main():
# Initialize LightRAG
# We use TiDB DB as the KV/vector
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,

View File

@@ -70,7 +70,7 @@ def main():
),
vector_storage="FaissVectorDBStorage",
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": 0.3 # Your desired threshold
"cosine_better_than_threshold": 0.2 # Your desired threshold
},
)

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.1.7"
__version__ = "1.1.11"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -1 +1,157 @@
# print ("init package vars here. ......")
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
def verify_storage_implementation(storage_type: str, storage_name: str) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)

View File

@@ -34,14 +34,9 @@ if not pm.is_installed("psycopg-pool"):
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
try:
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
except ImportError:
raise ImportError(
"`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
)
class AGEQueryException(Exception):

View File

@@ -10,13 +10,8 @@ import pipmaster as pm
if not pm.is_installed("chromadb"):
pm.install("chromadb")
try:
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
except ImportError as e:
raise ImportError(
"`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
) from e
@final
@@ -113,9 +108,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
logger.warning("Empty data provided to vector DB")
return []
return
try:
ids = list(data.keys())

View File

@@ -20,12 +20,7 @@ from lightrag.base import (
if not pm.is_installed("faiss"):
pm.install("faiss")
try:
import faiss
except ImportError as e:
raise ImportError(
"`faiss` library is not installed. Please install it via pip: `pip install faiss`."
) from e
@final
@@ -84,10 +79,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
...
}
"""
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
logger.warning("You are inserting empty data to the vector DB")
return []
return
current_time = time.time()

View File

@@ -2,6 +2,7 @@ import asyncio
import inspect
import json
import os
import pipmaster as pm
from dataclasses import dataclass
from typing import Any, Dict, List, final
@@ -20,14 +21,12 @@ from lightrag.utils import logger
from ..base import BaseGraphStorage
try:
if not pm.is_installed("gremlinpython"):
pm.install("gremlinpython")
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
except ImportError as e:
raise ImportError(
"`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
) from e
@final

View File

@@ -67,6 +67,10 @@ class JsonDocStatusStorage(DocStatusStorage):
write_json(self._data, self._file_name)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
self._data.update(data)
await self.index_done_callback()

View File

@@ -43,6 +43,9 @@ class JsonKVStorage(BaseKVStorage):
return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)

View File

@@ -14,13 +14,8 @@ if not pm.is_installed("configparser"):
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
try:
import configparser
from pymilvus import MilvusClient
except ImportError as e:
raise ImportError(
"`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
) from e
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -80,11 +75,11 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
list_data: list[dict[str, Any]] = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},

View File

@@ -25,7 +25,6 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
@@ -33,10 +32,6 @@ try:
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
raise ImportError(
"`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
) from e
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -113,8 +108,12 @@ class MongoKVStorage(BaseKVStorage):
return keys - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks = []
update_tasks: list[Any] = []
for mode, items in data.items():
for k, v in items.items():
key = f"{mode}_{k}"
@@ -186,7 +185,10 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
update_tasks = []
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
@@ -860,10 +862,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
logger.warning("You are inserting an empty data set to vector DB")
return []
return
list_data = [
{

View File

@@ -18,12 +18,7 @@ from lightrag.base import (
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
try:
from nano_vectordb import NanoVectorDB
except ImportError as e:
raise ImportError(
"`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
) from e
@final
@@ -50,10 +45,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
current_time = time.time()
list_data = [

View File

@@ -23,7 +23,6 @@ import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
try:
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
@@ -31,10 +30,6 @@ try:
AsyncManagedTransaction,
GraphDatabase,
)
except ImportError as e:
raise ImportError(
"`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
) from e
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

View File

@@ -17,16 +17,12 @@ import pipmaster as pm
if not pm.is_installed("networkx"):
pm.install("networkx")
if not pm.is_installed("graspologic"):
pm.install("graspologic")
try:
from graspologic import embed
import networkx as nx
except ImportError as e:
raise ImportError(
"`networkx` library is not installed. Please install it via pip: `pip install networkx`."
) from e
from graspologic import embed
@final

View File

@@ -26,15 +26,9 @@ if not pm.is_installed("graspologic"):
if not pm.is_installed("oracledb"):
pm.install("oracledb")
try:
from graspologic import embed
import oracledb
except ImportError as e:
raise ImportError(
"`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
) from e
class OracleDB:
def __init__(self, config, **kwargs):
@@ -51,7 +45,7 @@ class OracleDB:
self.increment = 1
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
if self.user is None or self.password is None:
raise ValueError("Missing database user or password in addon_params")
raise ValueError("Missing database user or password")
try:
oracledb.defaults.fetch_lobs = False
@@ -332,6 +326,10 @@ class OracleKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [
{

View File

@@ -38,15 +38,9 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
try:
import asyncpg
from asyncpg import Pool
except ImportError as e:
raise ImportError(
"`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
) from e
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -61,9 +55,7 @@ class PostgreSQLDB:
self.pool: Pool | None = None
if self.user is None or self.password is None or self.database is None:
raise ValueError(
"Missing database user, password, or database in addon_params"
)
raise ValueError("Missing database user, password, or database")
async def initdb(self):
try:
@@ -353,6 +345,10 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -454,10 +450,10 @@ class PGVectorStorage(BaseVectorStorage):
return upsert_sql, data
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
current_time = time.time()
list_data = [
{
@@ -618,6 +614,10 @@ class PGDocStatusStorage(DocStatusStorage):
Args:
data: dictionary of document IDs and their status data
"""
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7)
on conflict(id,workspace) do update set

View File

@@ -15,17 +15,11 @@ config.read("config.ini", "utf-8")
import pipmaster as pm
if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client")
if not pm.is_installed("qdrant-client"):
pm.install("qdrant-client")
try:
from qdrant_client import QdrantClient, models
except ImportError:
raise ImportError(
"`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
)
def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple"
@@ -93,9 +87,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
list_data = [
{
"id": k,

View File

@@ -49,6 +49,9 @@ class RedisKVStorage(BaseKVStorage):
return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
pipe = self._redis.pipeline()
for k, v in data.items():

View File

@@ -20,14 +20,8 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
try:
from sqlalchemy import create_engine, text
except ImportError as e:
raise ImportError(
"`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
) from e
class TiDB:
def __init__(self, config, **kwargs):
@@ -217,6 +211,9 @@ class TiDBKVStorage(BaseKVStorage):
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -324,12 +321,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
return []
return
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
list_data = [

View File

@@ -6,7 +6,13 @@ import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast
from typing import Any, AsyncIterator, Callable, Iterator, cast, final
from lightrag.kg import (
STORAGE_ENV_REQUIREMENTS,
STORAGES,
verify_storage_implementation,
)
from .base import (
BaseGraphStorage,
@@ -32,221 +38,37 @@ from .operate import (
from .prompt import GRAPH_FIELD_SEP
from .utils import (
EmbeddingFunc,
always_get_an_event_loop,
compute_mdhash_id,
convert_response_to_json,
lazy_external_import,
limit_async_func_call,
logger,
set_logger,
encode_string_by_tiktoken,
)
from .types import KnowledgeGraph
# TODO: TO REMOVE @Yannick
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
@final
@dataclass
class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation."""
# Directory
# ---
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
"""Directory where cache and temporary files are stored."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
# Storage
# ---
kv_storage: str = field(default="JsonKVStorage")
"""Storage backend for key-value data."""
@@ -261,32 +83,74 @@ class LightRAG:
"""Storage type for tracking document processing statuses."""
# Logging
current_log_level = logger.level
log_level: int = field(default=current_log_level)
# ---
log_level: int = field(default=logger.level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_dir: str = field(default=os.getcwd())
"""Directory where logs are stored. Defaults to the current working directory."""
# Text chunking
chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
"""Model name used for tokenization when chunking text."""
log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
"""Log file path."""
# Entity extraction
entity_extract_max_gleaning: int = 1
# ---
entity_extract_max_gleaning: int = field(default=1)
"""Maximum number of entity extraction attempts for ambiguous content."""
entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
entity_summary_to_max_tokens: int = field(
default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
)
# Text chunking
# ---
chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = field(
default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
"""Maximum number of tokens used for summarizing extracted entities."""
chunking_func: Callable[
[
str,
str | None,
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function should take the following parameters:
- `content`: The text to be split into chunks.
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
- `split_by_character_only`: If True, the text is split only on the specified character.
- `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
The function should return a list of dictionaries, where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
- `content`: The text content of the chunk.
Defaults to `chunking_by_token_size` if not specified.
"""
# Node embedding
node_embedding_algorithm: str = "node2vec"
# ---
node_embedding_algorithm: str = field(default="node2vec")
"""Algorithm used for node embedding in knowledge graphs."""
node2vec_params: dict[str, int] = field(
@@ -308,116 +172,102 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc | None = None
# Embedding
# ---
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
embedding_batch_num: int = field(default=32)
"""Batch size for embedding computations."""
embedding_func_max_async: int = 16
embedding_func_max_async: int = field(default=16)
"""Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
# LLM Configuration
llm_model_func: Callable[..., object] | None = None
# ---
llm_model_func: Callable[..., object] | None = field(default=None)
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
llm_model_name: str = field(default="gpt-4o-mini")
"""Name of the LLM model used for generating responses."""
llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
"""Maximum number of tokens allowed per LLM response."""
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
"""Maximum number of concurrent LLM calls."""
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Storage
# ---
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = True
enable_llm_cache: bool = field(default=True)
"""Enables caching for LLM responses to avoid redundant computations."""
enable_llm_cache_for_entity_extract: bool = True
enable_llm_cache_for_entity_extract: bool = field(default=True)
"""If True, enables caching for entity extraction steps to reduce LLM costs."""
# Extensions
# ---
max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
"""Maximum number of parallel insert operations."""
addon_params: dict[str, Any] = field(default_factory=dict)
# Storages Management
auto_manage_storages_states: bool = True
# ---
auto_manage_storages_states: bool = field(default=True)
"""If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
"""Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
# Storages Management
# ---
convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
default_factory=lambda: convert_response_to_json
)
# Custom Chunking Function
chunking_func: Callable[
[
str,
str | None,
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = chunking_by_token_size
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
Custom function for converting LLM responses to JSON format.
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
The default function is :func:`.utils.convert_response_to_json`.
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
cosine_better_than_threshold: float = field(
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
)
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
os.makedirs(self.log_dir, exist_ok=True)
log_file = os.path.join(self.log_dir, "lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
set_logger(self.log_file_path)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
@@ -432,22 +282,16 @@ class LightRAG:
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name)
verify_storage_implementation(storage_type, storage_name)
# Check environment variables
# self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
default_vector_db_kwargs = {
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
}
self.vector_db_storage_cls_kwargs = {
**default_vector_db_kwargs,
"cosine_better_than_threshold": self.cosine_better_than_threshold,
**self.vector_db_storage_cls_kwargs,
}
# Life cycle
self.storages_status = StoragesStatus.NOT_CREATED
# Show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -555,7 +399,7 @@ class LightRAG:
)
)
self.storages_status = StoragesStatus.CREATED
self._storages_status = StoragesStatus.CREATED
# Initialize storages
if self.auto_manage_storages_states:
@@ -570,7 +414,7 @@ class LightRAG:
async def initialize_storages(self):
"""Asynchronously initialize the storages"""
if self.storages_status == StoragesStatus.CREATED:
if self._storages_status == StoragesStatus.CREATED:
tasks = []
for storage in (
@@ -588,12 +432,12 @@ class LightRAG:
await asyncio.gather(*tasks)
self.storages_status = StoragesStatus.INITIALIZED
self._storages_status = StoragesStatus.INITIALIZED
logger.debug("Initialized Storages")
async def finalize_storages(self):
"""Asynchronously finalize the storages"""
if self.storages_status == StoragesStatus.INITIALIZED:
if self._storages_status == StoragesStatus.INITIALIZED:
tasks = []
for storage in (
@@ -611,7 +455,7 @@ class LightRAG:
await asyncio.gather(*tasks)
self.storages_status = StoragesStatus.FINALIZED
self._storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_graph_labels(self):
@@ -687,7 +531,7 @@ class LightRAG:
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
logger.info(f"Inserting {len(new_docs)} docs")
inserting_chunks: dict[str, Any] = {}
for chunk_text in text_chunks:
@@ -780,47 +624,45 @@ class LightRAG:
4. Update the document status
"""
# 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {}
# Run the asynchronous status retrievals in parallel using asyncio.gather
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
to_process_docs: dict[str, DocProcessingStatus] = {}
to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
to_process_docs.update(pendings_docs)
to_process_docs.update(pending_docs)
if not to_process_docs:
logger.info("All documents have been processed or are duplicates")
return
# 2. split docs into chunks, insert chunks, update doc status
batch_size = self.addon_params.get("insert_batch_size", 10)
docs_batches = [
list(to_process_docs.items())[i : i + batch_size]
for i in range(0, len(to_process_docs), batch_size)
list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert)
]
logger.info(f"Number of batches to process: {len(docs_batches)}.")
batches: list[Any] = []
# 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches):
async def batch(
batch_idx: int,
docs_batch: list[tuple[str, DocProcessingStatus]],
size_batch: int,
) -> None:
logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
# 4. iterate over batch
for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status
# Update status in processing
doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
await self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
)
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -836,12 +678,25 @@ class LightRAG:
self.tiktoken_model_name,
)
}
# Process document (text chunks and full docs) in parallel
tasks = [
self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
),
self.chunks_vdb.upsert(chunks),
self._process_entity_relation_graph(chunks),
self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
),
self.text_chunks.upsert(chunks),
]
try:
@@ -859,8 +714,6 @@ class LightRAG:
}
}
)
await self._insert_done()
except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.upsert(
@@ -879,9 +732,14 @@ class LightRAG:
continue
logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
await asyncio.gather(*batches)
await self._insert_done()
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
try:
new_kg = await extract_entities(
await extract_entities(
chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
@@ -889,12 +747,6 @@ class LightRAG:
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if new_kg is None:
logger.info("No new entities or relationships extracted.")
else:
logger.info("New entities or relationships extracted.")
self.chunk_entity_relation_graph = new_kg
except Exception as e:
logger.error("Failed to extract entities and relationships")
raise e
@@ -914,6 +766,7 @@ class LightRAG:
if storage_inst is not None
]
await asyncio.gather(*tasks)
logger.info("All Insert done")
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
loop = always_get_an_event_loop()
@@ -926,11 +779,28 @@ class LightRAG:
all_chunks_data: dict[str, dict[str, str]] = {}
chunk_to_source_map: dict[str, str] = {}
for chunk_data in custom_kg.get("chunks", {}):
chunk_content = chunk_data["content"]
chunk_content = chunk_data["content"].strip()
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
tokens = len(
encode_string_by_tiktoken(
chunk_content, model_name=self.tiktoken_model_name
)
)
chunk_order_index = (
0
if "chunk_order_index" not in chunk_data.keys()
else chunk_data["chunk_order_index"]
)
chunk_id = compute_mdhash_id(chunk_content, prefix="chunk-")
chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
chunk_entry = {
"content": chunk_content,
"source_id": source_id,
"tokens": tokens,
"chunk_order_index": chunk_order_index,
"full_doc_id": source_id,
"status": DocStatus.PROCESSED,
}
all_chunks_data[chunk_id] = chunk_entry
chunk_to_source_map[source_id] = chunk_id
update_storage = True
@@ -1177,7 +1047,6 @@ class LightRAG:
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
# We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
@@ -1603,3 +1472,21 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None
return result
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)

161
lightrag/llm/Readme.md Normal file
View File

@@ -0,0 +1,161 @@
1. **LlamaIndex** (`llm/llama_index.py`):
- Provides integration with OpenAI and other providers through LlamaIndex
- Supports both direct API access and proxy services like LiteLLM
- Handles embeddings and completions with consistent interfaces
- See example implementations:
- [Direct OpenAI Usage](../../examples/lightrag_llamaindex_direct_demo.py)
- [LiteLLM Proxy Usage](../../examples/lightrag_llamaindex_litellm_demo.py)
<details>
<summary> <b>Using LlamaIndex</b> </summary>
LightRAG supports LlamaIndex for embeddings and completions in two ways: direct OpenAI usage or through LiteLLM proxy.
### Setup
First, install the required dependencies:
```bash
pip install llama-index-llms-litellm llama-index-embeddings-litellm
```
### Standard OpenAI Usage
```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from lightrag.utils import EmbeddingFunc
# Initialize with direct OpenAI access
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize OpenAI if not in kwargs
if 'llm_instance' not in kwargs:
llm_instance = OpenAI(
model="gpt-4",
api_key="your-openai-key",
temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
response = await llama_index_complete_if_cache(
kwargs['llm_instance'],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
logger.error(f"LLM request failed: {str(e)}")
raise
# Initialize LightRAG with OpenAI
rag = LightRAG(
working_dir="your/path",
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(
texts,
embed_model=OpenAIEmbedding(
model="text-embedding-3-large",
api_key="your-openai-key"
)
),
),
)
```
### Using LiteLLM Proxy
1. Use any LLM provider through LiteLLM
2. Leverage LlamaIndex's embedding and completion capabilities
3. Maintain consistent configuration across services
```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
from lightrag.utils import EmbeddingFunc
# Initialize with LiteLLM proxy
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
try:
# Initialize LiteLLM if not in kwargs
if 'llm_instance' not in kwargs:
llm_instance = LiteLLM(
model=f"openai/{settings.LLM_MODEL}", # Format: "provider/model_name"
api_base=settings.LITELLM_URL,
api_key=settings.LITELLM_KEY,
temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
response = await llama_index_complete_if_cache(
kwargs['llm_instance'],
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return response
except Exception as e:
logger.error(f"LLM request failed: {str(e)}")
raise
# Initialize LightRAG with LiteLLM
rag = LightRAG(
working_dir="your/path",
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536,
max_token_size=8192,
func=lambda texts: llama_index_embed(
texts,
embed_model=LiteLLMEmbedding(
model_name=f"openai/{settings.EMBEDDING_MODEL}",
api_base=settings.LITELLM_URL,
api_key=settings.LITELLM_KEY,
)
),
),
)
```
### Environment Variables
For OpenAI direct usage:
```bash
OPENAI_API_KEY=your-openai-key
```
For LiteLLM proxy:
```bash
# LiteLLM Configuration
LITELLM_URL=http://litellm:4000
LITELLM_KEY=your-litellm-key
# Model Configuration
LLM_MODEL=gpt-4
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_MAX_TOKEN_SIZE=8192
```
### Key Differences
1. **Direct OpenAI**:
- Simpler setup
- Direct API access
- Requires OpenAI API key
2. **LiteLLM Proxy**:
- Model provider agnostic
- Centralized API key management
- Support for multiple providers
- Better cost control and monitoring
</details>

View File

@@ -0,0 +1,208 @@
import pipmaster as pm
from llama_index.core.llms import (
ChatMessage,
MessageRole,
ChatResponse,
)
from typing import List, Optional
from lightrag.utils import logger
# Install required dependencies
if not pm.is_installed("llama-index"):
pm.install("llama-index")
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import numpy as np
def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
"""
Configure LlamaIndex settings.
Args:
settings: LlamaIndex Settings instance. If None, uses default settings.
**kwargs: Additional settings to override/configure
"""
if settings is None:
settings = LlamaIndexSettings()
# Update settings with any provided kwargs
for key, value in kwargs.items():
if hasattr(settings, key):
setattr(settings, key, value)
else:
logger.warning(f"Unknown LlamaIndex setting: {key}")
# Set as global settings
LlamaIndexSettings.set_global(settings)
return settings
def format_chat_messages(messages):
"""Format chat messages into LlamaIndex format."""
formatted_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=content)
)
elif role == "assistant":
formatted_messages.append(
ChatMessage(role=MessageRole.ASSISTANT, content=content)
)
elif role == "user":
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
else:
logger.warning(f"Unknown role {role}, treating as user message")
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
return formatted_messages
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_complete_if_cache(
model: str,
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
**kwargs,
) -> str:
"""Complete the prompt using LlamaIndex."""
try:
# Format messages for chat
formatted_messages = []
# Add system message if provided
if system_prompt:
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
)
# Add history messages
for msg in history_messages:
formatted_messages.append(
ChatMessage(
role=MessageRole.USER
if msg["role"] == "user"
else MessageRole.ASSISTANT,
content=msg["content"],
)
)
# Add current prompt
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
# Get LLM instance from kwargs
if "llm_instance" not in kwargs:
raise ValueError("llm_instance must be provided in kwargs")
llm = kwargs["llm_instance"]
# Get response
response: ChatResponse = await llm.achat(messages=formatted_messages)
# In newer versions, the response is in message.content
content = response.message.content
return content
except Exception as e:
logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
raise
async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
) -> str:
"""
Main completion function for LlamaIndex
Args:
prompt: Input prompt
system_prompt: Optional system prompt
history_messages: Optional chat history
keyword_extraction: Whether to extract keywords from response
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await llama_index_complete_if_cache(
kwargs.get("llm_instance"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction:
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_embed(
texts: list[str],
embed_model: BaseEmbedding = None,
settings: LlamaIndexSettings = None,
**kwargs,
) -> np.ndarray:
"""
Generate embeddings using LlamaIndex
Args:
texts: List of texts to embed
embed_model: LlamaIndex embedding model
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if settings:
configure_llama_index(settings)
if embed_model is None:
raise ValueError("embed_model must be provided")
# Use _get_text_embeddings for batch processing
embeddings = embed_model._get_text_embeddings(texts)
return np.array(embeddings)

View File

@@ -329,7 +329,7 @@ async def extract_entities(
relationships_vdb: BaseVectorStorage,
global_config: dict[str, str],
llm_response_cache: BaseKVStorage | None = None,
) -> BaseGraphStorage | None:
) -> None:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
@@ -491,11 +491,9 @@ async def extract_entities(
already_processed += 1
already_entities += len(maybe_nodes)
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
logger.debug(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
f"Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
)
return dict(maybe_nodes), dict(maybe_edges)
@@ -524,16 +522,18 @@ async def extract_entities(
]
)
if not len(all_entities_data) and not len(all_relationships_data):
logger.warning(
"Didn't extract any entities and relationships, maybe your LLM is not working"
)
return None
if not (all_entities_data or all_relationships_data):
logger.info("Didn't extract any entities and relationships.")
return
if not len(all_entities_data):
logger.warning("Didn't extract any entities")
if not len(all_relationships_data):
logger.warning("Didn't extract any relationships")
if not all_entities_data:
logger.info("Didn't extract any entities")
if not all_relationships_data:
logger.info("Didn't extract any relationships")
logger.info(
f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}"
)
if entity_vdb is not None:
data_for_vdb = {
@@ -562,8 +562,6 @@ async def extract_entities(
}
await relationships_vdb.upsert(data_for_vdb)
return knowledge_graph_inst
async def kg_query(
query: str,
@@ -1328,15 +1326,12 @@ async def _get_edge_data(
),
)
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_datas = [
{
"src_id": k["src_id"],
"tgt_id": k["tgt_id"],
"rank": d,
"created_at": k.get("__created_at__", None), # 从 KV 存储中获取时间元数据
"created_at": k.get("__created_at__", None),
**v,
}
for k, v, d in zip(results, edge_datas, edge_degree)
@@ -1345,16 +1340,11 @@ async def _get_edge_data(
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
len_edge_datas = len(edge_datas)
edge_datas = truncate_list_by_token_size(
edge_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_global_context,
)
logger.debug(
f"Truncate relations from {len_edge_datas} to {len(edge_datas)} (max tokens:{query_param.max_token_for_global_context})"
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst

View File

@@ -9,15 +9,14 @@ PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["process_tickers"] = ["", "", "", "", "", "", "", "", "", ""]
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
PROMPTS["entity_extraction"] = """-Goal-
PROMPTS["entity_extraction"] = """---Goal---
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.
-Steps-
---Steps---
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
- entity_type: One of the following types: [{entity_types}]
@@ -41,18 +40,17 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
5. When finished, output {completion_delimiter}
######################
-Examples-
---Examples---
######################
{examples}
#############################
-Real Data-
---Real Data---
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
Output:"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
@@ -137,7 +135,7 @@ Make sure it is written in third person, and include the entity names so we the
Use {language} as output language.
#######
-Data-
---Data---
Entities: {entity_name}
Description List: {description_list}
#######
@@ -205,12 +203,12 @@ Given the query and conversation history, list both high-level and low-level key
- "low_level_keywords" for specific entities or details
######################
-Examples-
---Examples---
######################
{examples}
#############################
-Real Data-
---Real Data---
######################
Conversation History:
{history}

View File

@@ -713,3 +713,47 @@ def get_conversation_turns(
)
return "\n".join(formatted_turns)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class

View File

@@ -1,7 +1,7 @@
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import always_get_an_event_loop
def extract_queries(file_path):
@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):

View File

@@ -1,10 +1,9 @@
import os
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
import numpy as np
@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):