Merge remote-tracking branch 'origin/main' into refactor-api-server
61  .github/ISSUE_TEMPLATE/bug_report.yml  vendored  Normal file
@@ -0,0 +1,61 @@
name: Bug Report
description: File a bug report
title: "[Bug]: <title>"
labels: ["bug", "triage"]

body:
  - type: checkboxes
    id: existingcheck
    attributes:
      label: Do you need to file an issue?
      description: Please help us manage our time by avoiding duplicates and common bugs with the steps below.
      options:
        - label: I have searched the existing issues and this bug is not already filed.
        - label: I believe this is a legitimate bug, not just a question or feature request.
  - type: textarea
    id: description
    attributes:
      label: Describe the bug
      description: A clear and concise description of what the bug is.
      placeholder: What went wrong?
  - type: textarea
    id: reproduce
    attributes:
      label: Steps to reproduce
      description: Steps to reproduce the behavior.
      placeholder: How can we replicate the issue?
  - type: textarea
    id: expected_behavior
    attributes:
      label: Expected Behavior
      description: A clear and concise description of what you expected to happen.
      placeholder: What should have happened?
  - type: textarea
    id: configused
    attributes:
      label: LightRAG Config Used
      description: The LightRAG configuration used for the run.
      placeholder: The settings content or LightRAG configuration
      value: |
        # Paste your config here
  - type: textarea
    id: screenshotslogs
    attributes:
      label: Logs and screenshots
      description: If applicable, add screenshots and logs to help explain your problem.
      placeholder: Add logs and screenshots here
  - type: textarea
    id: additional_information
    attributes:
      label: Additional Information
      description: |
        - LightRAG Version: e.g., v0.1.1
        - Operating System: e.g., Windows 10, Ubuntu 20.04
        - Python Version: e.g., 3.8
        - Related Issues: e.g., #1
        - Any other relevant information.
      value: |
        - LightRAG Version:
        - Operating System:
        - Python Version:
        - Related Issues:
1  .github/ISSUE_TEMPLATE/config.yml  vendored  Normal file
@@ -0,0 +1 @@
blank_issues_enabled: false
26  .github/ISSUE_TEMPLATE/feature_request.yml  vendored  Normal file
@@ -0,0 +1,26 @@
name: Feature Request
description: File a feature request
labels: ["enhancement"]
title: "[Feature Request]: <title>"

body:
  - type: checkboxes
    id: existingcheck
    attributes:
      label: Do you need to file a feature request?
      description: Please help us manage our time by avoiding duplicates and common feature requests with the steps below.
      options:
        - label: I have searched the existing feature requests and this feature request is not already filed.
        - label: I believe this is a legitimate feature request, not just a question or bug.
  - type: textarea
    id: feature_request_description
    attributes:
      label: Feature Request Description
      description: A clear and concise description of the feature you would like.
      placeholder: What would this feature add or improve?
  - type: textarea
    id: additional_context
    attributes:
      label: Additional Context
      description: Add any other context or screenshots about the feature request here.
      placeholder: Any additional information
26  .github/ISSUE_TEMPLATE/question.yml  vendored  Normal file
@@ -0,0 +1,26 @@
name: Question
description: Ask a general question
labels: ["question"]
title: "[Question]: <title>"

body:
  - type: checkboxes
    id: existingcheck
    attributes:
      label: Do you need to ask a question?
      description: Please help us manage our time by avoiding duplicates and common questions with the steps below.
      options:
        - label: I have searched the existing questions and discussions and this question is not already answered.
        - label: I believe this is a legitimate question, not just a bug or feature request.
  - type: textarea
    id: question
    attributes:
      label: Your Question
      description: A clear and concise description of your question.
      placeholder: What is your question?
  - type: textarea
    id: context
    attributes:
      label: Additional Context
      description: Provide any additional context or details that might help us understand your question better.
      placeholder: Add any relevant information here
32  .github/pull_request_template.md  vendored  Normal file
@@ -0,0 +1,32 @@
<!--
Thanks for contributing to LightRAG!

Please ensure your pull request is ready for review before submitting.

About this template

This template helps contributors provide a clear and concise description of their changes. Feel free to adjust it as needed.
-->

## Description

[Briefly describe the changes made in this pull request.]

## Related Issues

[Reference any related issues or tasks addressed by this pull request.]

## Changes Made

[List the specific changes made in this pull request.]

## Checklist

- [ ] Changes tested locally
- [ ] Code reviewed
- [ ] Documentation updated (if necessary)
- [ ] Unit tests added (if applicable)

## Additional Notes

[Add any additional notes or context for the reviewer(s).]
34  README.md
@@ -312,7 +312,41 @@ rag = LightRAG(
 In order to run this experiment on a low-RAM GPU you should select a small model and tune the context window (increasing the context increases memory consumption). For example, running this ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`.

 </details>
 <details>
+<summary> <b>LlamaIndex</b> </summary>
+
+LightRAG supports integration with LlamaIndex.
+
+1. **LlamaIndex** (`llm/llama_index_impl.py`):
+   - Integrates with OpenAI and other providers through LlamaIndex
+   - See [LlamaIndex Documentation](lightrag/llm/Readme.md) for detailed setup and examples
+
+### Example Usage
+
+```python
+# Using LlamaIndex with direct OpenAI access
+from lightrag import LightRAG
+from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from lightrag.utils import EmbeddingFunc
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
+
+embed_model = OpenAIEmbedding()  # the lambda below needs an embed_model instance in scope
+
+rag = LightRAG(
+    working_dir="your/path",
+    llm_model_func=llama_index_complete_if_cache,  # LlamaIndex-compatible completion function
+    embedding_func=EmbeddingFunc(  # LlamaIndex-compatible embedding function
+        embedding_dim=1536,
+        max_token_size=8192,
+        func=lambda texts: llama_index_embed(texts, embed_model=embed_model)
+    ),
+)
+```
+
+#### For detailed documentation and examples, see:
+- [LlamaIndex Documentation](lightrag/llm/Readme.md)
+- [Direct OpenAI Example](examples/lightrag_llamaindex_direct_demo.py)
+- [LiteLLM Proxy Example](examples/lightrag_llamaindex_litellm_demo.py)
+
+</details>
 <details>
 <summary> <b>Conversation History Support</b> </summary>

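Once the LlamaIndex-backed instance above is constructed, querying works like any other LightRAG setup. A minimal sketch (assuming documents were already inserted with `rag.insert(...)`):

```python
from lightrag import QueryParam

# mode can be "naive", "local", "global", or "hybrid"
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```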
@@ -1,5 +1,3 @@
-version: '3.8'
-
 services:
   lightrag:
     build: .
@@ -87,18 +87,27 @@ custom_kg = {
         {
             "content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.",
             "source_id": "Source1",
+            "source_chunk_index": 0,
         },
         {
             "content": "One outstanding feature of ProductX is its advanced AI capabilities.",
             "source_id": "Source1",
+            "chunk_order_index": 1,
         },
         {
             "content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.",
             "source_id": "Source2",
+            "source_chunk_index": 0,
         },
         {
             "content": "EventY, held in CityC, attracts technology enthusiasts and companies from around the globe.",
             "source_id": "Source3",
+            "source_chunk_index": 0,
         },
+        {
+            "content": "None",
+            "source_id": "UNKNOWN",
+            "source_chunk_index": 0,
+        },
     ],
 }
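A dictionary shaped like this is handed to LightRAG's custom-KG entry point. A usage sketch (assuming `rag` is an initialized `LightRAG` instance, as in the other examples in this diff):

```python
# Inserts pre-built entities/relationships/chunks directly instead of
# extracting them from raw text; `custom_kg` is the dictionary shown above.
rag.insert_custom_kg(custom_kg)
```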
@@ -98,7 +98,6 @@ async def init():

    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
113  examples/lightrag_llamaindex_direct_demo.py  Normal file
@@ -0,0 +1,113 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
    llama_index_complete_if_cache,
    llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
import asyncio

# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")

# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

# OpenAI configuration
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key-here")

if not os.path.exists(WORKING_DIR):
    print(f"Creating working directory: {WORKING_DIR}")
    os.mkdir(WORKING_DIR)


# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    try:
        # Initialize OpenAI if not in kwargs
        if "llm_instance" not in kwargs:
            llm_instance = OpenAI(
                model=LLM_MODEL,
                api_key=OPENAI_API_KEY,
                temperature=0.7,
            )
            kwargs["llm_instance"] = llm_instance

        response = await llama_index_complete_if_cache(
            kwargs["llm_instance"],
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )
        return response
    except Exception as e:
        print(f"LLM request failed: {str(e)}")
        raise


# Initialize embedding function
async def embedding_func(texts):
    try:
        embed_model = OpenAIEmbedding(
            model=EMBEDDING_MODEL,
            api_key=OPENAI_API_KEY,
        )
        return await llama_index_embed(texts, embed_model=embed_model)
    except Exception as e:
        print(f"Embedding failed: {str(e)}")
        raise


# Get embedding dimension
async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print(f"embedding_dim={embedding_dim}")
    return embedding_dim


# Initialize RAG instance
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=asyncio.run(get_embedding_dim()),
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    ),
)

# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Test different query modes
print("\nNaive Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

print("\nLocal Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

print("\nGlobal Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

print("\nHybrid Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
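One readability note on the demo above: `llm_model_func` constructs a fresh `OpenAI` client on every call unless a caller passes one through `kwargs`. A sketch of hoisting the client to module level (same names and behavior assumed as in the demo):

```python
# Build the client once; reuse it for every completion request.
shared_llm = OpenAI(model=LLM_MODEL, api_key=OPENAI_API_KEY, temperature=0.7)

async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    kwargs.setdefault("llm_instance", shared_llm)
    return await llama_index_complete_if_cache(
        kwargs["llm_instance"],
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
```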
116  examples/lightrag_llamaindex_litellm_demo.py  Normal file
@@ -0,0 +1,116 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.llama_index_impl import (
    llama_index_complete_if_cache,
    llama_index_embed,
)
from lightrag.utils import EmbeddingFunc
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
import asyncio

# Configure working directory
WORKING_DIR = "./index_default"
print(f"WORKING_DIR: {WORKING_DIR}")

# Model configuration
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

# LiteLLM configuration
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
print(f"LITELLM_URL: {LITELLM_URL}")
LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-1234")

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


# Initialize LLM function
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    try:
        # Initialize LiteLLM if not in kwargs
        if "llm_instance" not in kwargs:
            llm_instance = LiteLLM(
                model=f"openai/{LLM_MODEL}",  # Format: "provider/model_name"
                api_base=LITELLM_URL,
                api_key=LITELLM_KEY,
                temperature=0.7,
            )
            kwargs["llm_instance"] = llm_instance

        response = await llama_index_complete_if_cache(
            kwargs["llm_instance"],
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )
        return response
    except Exception as e:
        print(f"LLM request failed: {str(e)}")
        raise


# Initialize embedding function
async def embedding_func(texts):
    try:
        embed_model = LiteLLMEmbedding(
            model_name=f"openai/{EMBEDDING_MODEL}",
            api_base=LITELLM_URL,
            api_key=LITELLM_KEY,
        )
        return await llama_index_embed(texts, embed_model=embed_model)
    except Exception as e:
        print(f"Embedding failed: {str(e)}")
        raise


# Get embedding dimension
async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print(f"embedding_dim={embedding_dim}")
    return embedding_dim


# Initialize RAG instance
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=asyncio.run(get_embedding_dim()),
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    ),
)

# Insert example text
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Test different query modes
print("\nNaive Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

print("\nLocal Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

print("\nGlobal Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

print("\nHybrid Search:")
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
@@ -1,9 +1,8 @@
-import os
 import inspect
+import os
 from lightrag import LightRAG
 from lightrag.llm import openai_complete, openai_embed
-from lightrag.utils import EmbeddingFunc
-from lightrag.lightrag import always_get_an_event_loop
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 from lightrag import QueryParam

 # WorkingDir
@@ -63,7 +63,6 @@ async def main():

    # Initialize LightRAG
    # We use TiDB as the KV/vector storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
@@ -70,7 +70,7 @@ def main():
         ),
         vector_storage="FaissVectorDBStorage",
         vector_db_storage_cls_kwargs={
-            "cosine_better_than_threshold": 0.3  # Your desired threshold
+            "cosine_better_than_threshold": 0.2  # Your desired threshold
         },
     )
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.1.7"
+__version__ = "1.1.11"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -1 +1,157 @@
# print ("init package vars here. ......")
STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "MongoKVStorage",
            "RedisKVStorage",
            "TiDBKVStorage",
            "PGKVStorage",
            "OracleKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
    "GRAPH_STORAGE": {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "MongoGraphStorage",
            "TiDBGraphStorage",
            "AGEStorage",
            "GremlinStorage",
            "PGGraphStorage",
            "OracleGraphStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
    "VECTOR_STORAGE": {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "TiDBVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
    "DOC_STATUS_STORAGE": {
        "implementations": [
            "JsonDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
    },
}

# Required environment variables (no default values) for each storage implementation
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "OracleKVStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [],
    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    "OracleGraphStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [],
    "ChromaVectorDBStorage": [],
    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "OracleVectorDBStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "MongoDocStatusStorage": [],
}

# Storage implementation module mapping
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
}


def verify_storage_implementation(storage_type: str, storage_name: str) -> None:
    """Verify if storage implementation is compatible with specified storage type

    Args:
        storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
        storage_name: Storage implementation name

    Raises:
        ValueError: If storage implementation is incompatible or missing required methods
    """
    if storage_type not in STORAGE_IMPLEMENTATIONS:
        raise ValueError(f"Unknown storage type: {storage_type}")

    storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
    if storage_name not in storage_info["implementations"]:
        raise ValueError(
            f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
            f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
        )
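The contract above is easy to exercise directly; a quick sketch:

```python
from lightrag.kg import verify_storage_implementation

verify_storage_implementation("KV_STORAGE", "JsonKVStorage")  # compatible pair: returns None

try:
    verify_storage_implementation("KV_STORAGE", "NetworkXStorage")  # a graph backend
except ValueError as err:
    print(err)  # names the valid KV_STORAGE implementations
```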
@@ -34,14 +34,9 @@ if not pm.is_installed("psycopg-pool"):
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")

-try:
-    import psycopg
-    from psycopg.rows import namedtuple_row
-    from psycopg_pool import AsyncConnectionPool, PoolTimeout
-except ImportError:
-    raise ImportError(
-        "`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`."
-    )
+import psycopg
+from psycopg.rows import namedtuple_row
+from psycopg_pool import AsyncConnectionPool, PoolTimeout


 class AGEQueryException(Exception):
@@ -10,13 +10,8 @@
 if not pm.is_installed("chromadb"):
     pm.install("chromadb")

-try:
-    from chromadb import HttpClient, PersistentClient
-    from chromadb.config import Settings
-except ImportError as e:
-    raise ImportError(
-        "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
-    ) from e
+from chromadb import HttpClient, PersistentClient
+from chromadb.config import Settings


 @final
@@ -113,9 +108,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             raise

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             logger.warning("Empty data provided to vector DB")
-            return []
+            return

         try:
             ids = list(data.keys())
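The same two changes repeat across the storage backends below: `upsert` logs one consistent message and, on empty input, returns `None` instead of `[]`, matching its declared `-> None` signature. The shared shape, as a sketch:

```python
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    logger.info(f"Inserting {len(data)} to {self.namespace}")
    if not data:
        return  # a bare return keeps the annotated None return type honest
    # ... backend-specific write path follows
```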
@@ -20,12 +20,7 @@ from lightrag.base import (
 if not pm.is_installed("faiss"):
     pm.install("faiss")

-try:
-    import faiss
-except ImportError as e:
-    raise ImportError(
-        "`faiss` library is not installed. Please install it via pip: `pip install faiss`."
-    ) from e
+import faiss


 @final
@@ -84,10 +79,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
             ...
         }
         """
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             logger.warning("You are inserting empty data to the vector DB")
-            return []
+            return

         current_time = time.time()
@@ -2,6 +2,7 @@ import asyncio
 import inspect
 import json
 import os
+import pipmaster as pm
 from dataclasses import dataclass
 from typing import Any, Dict, List, final

@@ -20,14 +21,12 @@ from lightrag.utils import logger

 from ..base import BaseGraphStorage

-try:
-    from gremlin_python.driver import client, serializer
-    from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-    from gremlin_python.driver.protocol import GremlinServerError
-except ImportError as e:
-    raise ImportError(
-        "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`."
-    ) from e
+if not pm.is_installed("gremlinpython"):
+    pm.install("gremlinpython")
+
+from gremlin_python.driver import client, serializer
+from gremlin_python.driver.aiohttp.transport import AiohttpTransport
+from gremlin_python.driver.protocol import GremlinServerError


 @final
@@ -67,6 +67,10 @@ class JsonDocStatusStorage(DocStatusStorage):
         write_json(self._data, self._file_name)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         self._data.update(data)
         await self.index_done_callback()
@@ -43,6 +43,9 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
@@ -14,13 +14,8 @@ if not pm.is_installed("configparser"):
 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")

-try:
-    import configparser
-    from pymilvus import MilvusClient
-except ImportError as e:
-    raise ImportError(
-        "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`."
-    ) from e
+import configparser
+from pymilvus import MilvusClient

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -80,11 +75,11 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
-        list_data = [
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
+        list_data: list[dict[str, Any]] = [
             {
                 "id": k,
                 **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
@@ -25,18 +25,13 @@ if not pm.is_installed("pymongo"):
 if not pm.is_installed("motor"):
     pm.install("motor")

-try:
-    from motor.motor_asyncio import (
-        AsyncIOMotorClient,
-        AsyncIOMotorDatabase,
-        AsyncIOMotorCollection,
-    )
-    from pymongo.operations import SearchIndexModel
-    from pymongo.errors import PyMongoError
-except ImportError as e:
-    raise ImportError(
-        "`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`."
-    ) from e
+from motor.motor_asyncio import (
+    AsyncIOMotorClient,
+    AsyncIOMotorDatabase,
+    AsyncIOMotorCollection,
+)
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -113,8 +108,12 @@ class MongoKVStorage(BaseKVStorage):
         return keys - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            update_tasks = []
+            update_tasks: list[Any] = []
             for mode, items in data.items():
                 for k, v in items.items():
                     key = f"{mode}_{k}"
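For the LLM-response-cache namespace, `data` arrives nested one level deeper than for other namespaces, which is why the loop above flattens `{mode: {key: value}}` into `f"{mode}_{k}"` document ids. A hypothetical payload to illustrate:

```python
# Illustrative only: a cache payload keyed first by query mode, then by prompt hash.
data = {
    "hybrid": {
        "5eb63bbbe01eee": {"return": "cached answer", "model": "gpt-4o-mini"},
    },
}
# The item above would be stored under the Mongo _id "hybrid_5eb63bbbe01eee".
```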
@@ -186,7 +185,10 @@ class MongoDocStatusStorage(DocStatusStorage):
         return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        update_tasks = []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+        update_tasks: list[Any] = []
         for k, v in data.items():
             data[k]["_id"] = k
             update_tasks.append(
@@ -860,10 +862,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
             logger.debug("vector index already exist")

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
-            logger.warning("You are inserting an empty data set to vector DB")
-            return []
+            return

         list_data = [
             {
@@ -18,12 +18,7 @@ from lightrag.base import (
 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")

-try:
-    from nano_vectordb import NanoVectorDB
-except ImportError as e:
-    raise ImportError(
-        "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`."
-    ) from e
+from nano_vectordb import NanoVectorDB


 @final
@@ -50,10 +45,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return

         current_time = time.time()
         list_data = [
@@ -23,18 +23,13 @@ import pipmaster as pm
 if not pm.is_installed("neo4j"):
     pm.install("neo4j")

-try:
-    from neo4j import (
-        AsyncGraphDatabase,
-        exceptions as neo4jExceptions,
-        AsyncDriver,
-        AsyncManagedTransaction,
-        GraphDatabase,
-    )
-except ImportError as e:
-    raise ImportError(
-        "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`."
-    ) from e
+from neo4j import (
+    AsyncGraphDatabase,
+    exceptions as neo4jExceptions,
+    AsyncDriver,
+    AsyncManagedTransaction,
+    GraphDatabase,
+)

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -17,16 +17,12 @@ import pipmaster as pm

 if not pm.is_installed("networkx"):
     pm.install("networkx")

 if not pm.is_installed("graspologic"):
     pm.install("graspologic")

-try:
-    from graspologic import embed
-    import networkx as nx
-except ImportError as e:
-    raise ImportError(
-        "`networkx` library is not installed. Please install it via pip: `pip install networkx`."
-    ) from e
+import networkx as nx
+from graspologic import embed


 @final
@@ -26,14 +26,8 @@ if not pm.is_installed("graspologic"):
 if not pm.is_installed("oracledb"):
     pm.install("oracledb")

-try:
-    from graspologic import embed
-    import oracledb
-
-except ImportError as e:
-    raise ImportError(
-        "`oracledb` library is not installed. Please install it via pip: `pip install oracledb`."
-    ) from e
+from graspologic import embed
+import oracledb


 class OracleDB:
@@ -51,7 +45,7 @@ class OracleDB:
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
         if self.user is None or self.password is None:
-            raise ValueError("Missing database user or password in addon_params")
+            raise ValueError("Missing database user or password")

         try:
             oracledb.defaults.fetch_lobs = False
@@ -332,6 +326,10 @@ class OracleKVStorage(BaseKVStorage):

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
@@ -38,14 +38,8 @@ import pipmaster as pm
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")

-try:
-    import asyncpg
-    from asyncpg import Pool
-
-except ImportError as e:
-    raise ImportError(
-        "`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
-    ) from e
+import asyncpg
+from asyncpg import Pool


 class PostgreSQLDB:
@@ -61,9 +55,7 @@ class PostgreSQLDB:
         self.pool: Pool | None = None

         if self.user is None or self.password is None or self.database is None:
-            raise ValueError(
-                "Missing database user, password, or database in addon_params"
-            )
+            raise ValueError("Missing database user, password, or database")

     async def initdb(self):
         try:
@@ -353,6 +345,10 @@ class PGKVStorage(BaseKVStorage):

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -454,10 +450,10 @@ class PGVectorStorage(BaseVectorStorage):
         return upsert_sql, data

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         current_time = time.time()
         list_data = [
             {
@@ -618,6 +614,10 @@ class PGDocStatusStorage(DocStatusStorage):
         Args:
             data: dictionary of document IDs and their status data
         """
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
+
         sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
                  values($1,$2,$3,$4,$5,$6,$7)
                   on conflict(id,workspace) do update set
@@ -15,16 +15,10 @@ config.read("config.ini", "utf-8")

 import pipmaster as pm

-if not pm.is_installed("qdrant_client"):
-    pm.install("qdrant_client")
+if not pm.is_installed("qdrant-client"):
+    pm.install("qdrant-client")

-try:
-    from qdrant_client import QdrantClient, models
-
-except ImportError:
-    raise ImportError(
-        "`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
-    )
+from qdrant_client import QdrantClient, models


 def compute_mdhash_id_for_qdrant(
@@ -93,9 +87,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
         list_data = [
             {
                 "id": k,
@@ -49,6 +49,9 @@ class RedisKVStorage(BaseKVStorage):
         return set(keys) - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
         pipe = self._redis.pipeline()

         for k, v in data.items():
@@ -20,13 +20,7 @@ if not pm.is_installed("pymysql"):
 if not pm.is_installed("sqlalchemy"):
     pm.install("sqlalchemy")

-try:
-    from sqlalchemy import create_engine, text
-
-except ImportError as e:
-    raise ImportError(
-        "`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`."
-    ) from e
+from sqlalchemy import create_engine, text


 class TiDB:
@@ -217,6 +211,9 @@ class TiDBKVStorage(BaseKVStorage):

     ################ INSERT full_doc AND chunks ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -324,12 +321,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):

     ###### INSERT entities And relationships ######
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         # ignore, upsert in TiDBKVStorage already
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        if not data:
+            return
         if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
-            return []
+            return

         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

         list_data = [
@@ -6,7 +6,13 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final
+
+from lightrag.kg import (
+    STORAGE_ENV_REQUIREMENTS,
+    STORAGES,
+    verify_storage_implementation,
+)

 from .base import (
     BaseGraphStorage,
@@ -32,221 +38,37 @@ from .operate import (
 from .prompt import GRAPH_FIELD_SEP
 from .utils import (
     EmbeddingFunc,
     always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
     lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
     encode_string_by_tiktoken,
 )
 from .types import KnowledgeGraph

 # TODO: TO REMOVE @Yannick
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

-# Storage type and implementation compatibility validation table
-STORAGE_IMPLEMENTATIONS = {
-    "KV_STORAGE": {
-        "implementations": [
-            "JsonKVStorage",
-            "MongoKVStorage",
-            "RedisKVStorage",
-            "TiDBKVStorage",
-            "PGKVStorage",
-            "OracleKVStorage",
-        ],
-        "required_methods": ["get_by_id", "upsert"],
-    },
-    "GRAPH_STORAGE": {
-        "implementations": [
-            "NetworkXStorage",
-            "Neo4JStorage",
-            "MongoGraphStorage",
-            "TiDBGraphStorage",
-            "AGEStorage",
-            "GremlinStorage",
-            "PGGraphStorage",
-            "OracleGraphStorage",
-        ],
-        "required_methods": ["upsert_node", "upsert_edge"],
-    },
-    "VECTOR_STORAGE": {
-        "implementations": [
-            "NanoVectorDBStorage",
-            "MilvusVectorDBStorage",
-            "ChromaVectorDBStorage",
-            "TiDBVectorDBStorage",
-            "PGVectorStorage",
-            "FaissVectorDBStorage",
-            "QdrantVectorDBStorage",
-            "OracleVectorDBStorage",
-            "MongoVectorDBStorage",
-        ],
-        "required_methods": ["query", "upsert"],
-    },
-    "DOC_STATUS_STORAGE": {
-        "implementations": [
-            "JsonDocStatusStorage",
-            "PGDocStatusStorage",
-            "MongoDocStatusStorage",
-        ],
-        "required_methods": ["get_docs_by_status"],
-    },
-}
-
-# Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
-    # KV Storage Implementations
-    "JsonKVStorage": [],
-    "MongoKVStorage": [],
-    "RedisKVStorage": ["REDIS_URI"],
-    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "OracleKVStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Graph Storage Implementations
-    "NetworkXStorage": [],
-    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
-    "MongoGraphStorage": [],
-    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "AGEStorage": [
-        "AGE_POSTGRES_DB",
-        "AGE_POSTGRES_USER",
-        "AGE_POSTGRES_PASSWORD",
-    ],
-    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
-    "PGGraphStorage": [
-        "POSTGRES_USER",
-        "POSTGRES_PASSWORD",
-        "POSTGRES_DATABASE",
-    ],
-    "OracleGraphStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Vector Storage Implementations
-    "NanoVectorDBStorage": [],
-    "MilvusVectorDBStorage": [],
-    "ChromaVectorDBStorage": [],
-    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "FaissVectorDBStorage": [],
-    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
-    "OracleVectorDBStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    "MongoVectorDBStorage": [],
-    # Document Status Storage Implementations
-    "JsonDocStatusStorage": [],
-    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "MongoDocStatusStorage": [],
-}
-
-# Storage implementation module mapping
-STORAGES = {
-    "NetworkXStorage": ".kg.networkx_impl",
-    "JsonKVStorage": ".kg.json_kv_impl",
-    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
-    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
-    "Neo4JStorage": ".kg.neo4j_impl",
-    "OracleKVStorage": ".kg.oracle_impl",
-    "OracleGraphStorage": ".kg.oracle_impl",
-    "OracleVectorDBStorage": ".kg.oracle_impl",
-    "MilvusVectorDBStorage": ".kg.milvus_impl",
-    "MongoKVStorage": ".kg.mongo_impl",
-    "MongoDocStatusStorage": ".kg.mongo_impl",
-    "MongoGraphStorage": ".kg.mongo_impl",
-    "MongoVectorDBStorage": ".kg.mongo_impl",
-    "RedisKVStorage": ".kg.redis_impl",
-    "ChromaVectorDBStorage": ".kg.chroma_impl",
-    "TiDBKVStorage": ".kg.tidb_impl",
-    "TiDBVectorDBStorage": ".kg.tidb_impl",
-    "TiDBGraphStorage": ".kg.tidb_impl",
-    "PGKVStorage": ".kg.postgres_impl",
-    "PGVectorStorage": ".kg.postgres_impl",
-    "AGEStorage": ".kg.age_impl",
-    "PGGraphStorage": ".kg.postgres_impl",
-    "GremlinStorage": ".kg.gremlin_impl",
-    "PGDocStatusStorage": ".kg.postgres_impl",
-    "FaissVectorDBStorage": ".kg.faiss_impl",
-    "QdrantVectorDBStorage": ".kg.qdrant_impl",
-}
-
-
-def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
-    """Lazily import a class from an external module based on the package of the caller."""
-    # Get the caller's module and package
-    import inspect
-
-    caller_frame = inspect.currentframe().f_back
-    module = inspect.getmodule(caller_frame)
-    package = module.__package__ if module else None
-
-    def import_class(*args: Any, **kwargs: Any):
-        import importlib
-
-        module = importlib.import_module(module_name, package=package)
-        cls = getattr(module, class_name)
-        return cls(*args, **kwargs)
-
-    return import_class
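The lazy-import factory removed above (now imported from `lightrag.utils`) defers heavy optional dependencies until a storage class is actually constructed. A usage sketch (the constructor arguments are hypothetical, for illustration only):

```python
# The STORAGES mapping pairs each class name with its module path, so a
# backend can be resolved on demand relative to the lightrag package.
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")

# Nothing is imported until this first call; these args are illustrative only.
graph_store = Neo4JStorage(namespace="graph", global_config={}, embedding_func=None)
```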

-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    """
-    Ensure that there is always an event loop available.
-
-    This function tries to get the current event loop. If the current event loop is closed or does not exist,
-    it creates a new event loop and sets it as the current event loop.
-
-    Returns:
-        asyncio.AbstractEventLoop: The current or newly created event loop.
-    """
-    try:
-        # Try to get the current event loop
-        current_loop = asyncio.get_event_loop()
-        if current_loop.is_closed():
-            raise RuntimeError("Event loop is closed.")
-        return current_loop
-
-    except RuntimeError:
-        # If no event loop exists or it is closed, create a new one
-        logger.info("Creating a new event loop in main thread.")
-        new_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(new_loop)
-        return new_loop
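This helper (also moved to `lightrag.utils`; the example scripts earlier in this diff now import it from there) is what lets synchronous methods drive the async pipeline. A sketch of the pattern (`self._ainsert` is a placeholder name for the real coroutine):

```python
# Synchronous facade over an async implementation, as LightRAG's sync
# methods do internally.
def insert_sync(self, text: str):
    loop = always_get_an_event_loop()
    return loop.run_until_complete(self._ainsert(text))
```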

+@final
 @dataclass
 class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

     # Directory
     # ---

     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""

-    embedding_cache_config: dict[str, Any] = field(
-        default_factory=lambda: {
-            "enabled": False,
-            "similarity_threshold": 0.95,
-            "use_llm_check": False,
-        }
-    )
-    """Configuration for embedding cache.
-    - enabled: If True, enables caching to avoid redundant computations.
-    - similarity_threshold: Minimum similarity score to use cached embeddings.
-    - use_llm_check: If True, validates cached embeddings using an LLM.
-    """
+
     # Storage
     # ---

     kv_storage: str = field(default="JsonKVStorage")
     """Storage backend for key-value data."""
@@ -261,32 +83,74 @@ class LightRAG:
     """Storage type for tracking document processing statuses."""

     # Logging
-    current_log_level = logger.level
-    log_level: int = field(default=current_log_level)
+    # ---
+
+    log_level: int = field(default=logger.level)
     """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""

-    log_dir: str = field(default=os.getcwd())
-    """Directory where logs are stored. Defaults to the current working directory."""
-
-    # Text chunking
-    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
-    """Maximum number of tokens per text chunk when splitting documents."""
-
-    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
-    """Number of overlapping tokens between consecutive text chunks to preserve context."""
-
-    tiktoken_model_name: str = "gpt-4o-mini"
-    """Model name used for tokenization when chunking text."""
+    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
+    """Log file path."""

     # Entity extraction
-    entity_extract_max_gleaning: int = 1
+    # ---
+
+    entity_extract_max_gleaning: int = field(default=1)
     """Maximum number of entity extraction attempts for ambiguous content."""

-    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
+    entity_summary_to_max_tokens: int = field(
+        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    )
+    """Maximum number of tokens used for summarizing extracted entities."""
+
+    # Text chunking
+    # ---
+
+    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
+    """Maximum number of tokens per text chunk when splitting documents."""
+
+    chunk_overlap_token_size: int = field(
+        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
+    )
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
+    tiktoken_model_name: str = field(default="gpt-4o-mini")
+    """Model name used for tokenization when chunking text."""
+
+    chunking_func: Callable[
+        [
+            str,
+            str | None,
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = field(default_factory=lambda: chunking_by_token_size)
+    """
+    Custom chunking function for splitting text into chunks before processing.
+
+    The function should take the following parameters:
+
+    - `content`: The text to be split into chunks.
+    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
+    - `split_by_character_only`: If True, the text is split only on the specified character.
+    - `chunk_token_size`: The maximum number of tokens per chunk.
+    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
+    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
+
+    The function should return a list of dictionaries, where each dictionary contains the following keys:
+    - `tokens`: The number of tokens in the chunk.
+    - `content`: The text content of the chunk.
+
+    Defaults to `chunking_by_token_size` if not specified.
+    """

     # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+
+    node_embedding_algorithm: str = field(default="node2vec")
     """Algorithm used for node embedding in knowledge graphs."""

     node2vec_params: dict[str, int] = field(
@@ -308,116 +172,102 @@ class LightRAG:
|
||||
- random_seed: Seed value for reproducibility.
|
||||
"""
|
||||
|
||||
embedding_func: EmbeddingFunc | None = None
|
||||
# Embedding
|
||||
# ---
|
||||
|
||||
embedding_func: EmbeddingFunc | None = field(default=None)
|
||||
"""Function for computing text embeddings. Must be set before use."""
|
||||
|
||||
embedding_batch_num: int = 32
|
||||
embedding_batch_num: int = field(default=32)
|
||||
"""Batch size for embedding computations."""
|
||||
|
||||
embedding_func_max_async: int = 16
|
||||
embedding_func_max_async: int = field(default=16)
|
||||
"""Maximum number of concurrent embedding function calls."""
|
||||
|
||||
embedding_cache_config: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"enabled": False,
|
||||
"similarity_threshold": 0.95,
|
||||
"use_llm_check": False,
|
||||
}
|
||||
)
|
||||
"""Configuration for embedding cache.
|
||||
- enabled: If True, enables caching to avoid redundant computations.
|
||||
- similarity_threshold: Minimum similarity score to use cached embeddings.
|
||||
- use_llm_check: If True, validates cached embeddings using an LLM.
|
||||
"""
|
||||

    # LLM Configuration
    llm_model_func: Callable[..., object] | None = None
    # ---

    llm_model_func: Callable[..., object] | None = field(default=None)
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    llm_model_name: str = field(default="gpt-4o-mini")
    """Name of the LLM model used for generating responses."""

    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
    """Maximum number of tokens allowed per LLM response."""

    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
    """Maximum number of concurrent LLM calls."""

    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Storage
    # ---

    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    namespace_prefix: str = field(default="")
    """Prefix for namespacing stored data across different environments."""

    enable_llm_cache: bool = True
    enable_llm_cache: bool = field(default=True)
    """Enables caching for LLM responses to avoid redundant computations."""

    enable_llm_cache_for_entity_extract: bool = True
    enable_llm_cache_for_entity_extract: bool = field(default=True)
    """If True, enables caching for entity extraction steps to reduce LLM costs."""

    # Extensions
    # ---

    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
    """Maximum number of parallel insert operations."""

    addon_params: dict[str, Any] = field(default_factory=dict)

    # Storages Management
    auto_manage_storages_states: bool = True
    # ---

    auto_manage_storages_states: bool = field(default=True)
    """If True, LightRAG will automatically call initialize_storages and finalize_storages at the appropriate times."""

    """Dictionary for additional parameters and extensions."""
    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
        convert_response_to_json
    # Storages Management
    # ---

    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
        default_factory=lambda: convert_response_to_json
    )
    """
    Custom function for converting LLM responses to JSON format.

    The default function is :func:`.utils.convert_response_to_json`.
    """

    cosine_better_than_threshold: float = field(
        default=float(os.getenv("COSINE_THRESHOLD", 0.2))
    )

    # Custom Chunking Function
    chunking_func: Callable[
        [
            str,
            str | None,
            bool,
            int,
            int,
            str,
        ],
        list[dict[str, Any]],
    ] = chunking_by_token_size
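
    # A sketch of a custom chunking function matching this signature (our
    # illustration; the name and splitting rule are hypothetical):
    #   def chunk_by_paragraph(content, split_by_character, split_by_character_only,
    #                          overlap_token_size, max_token_size, tiktoken_model):
    #       return [
    #           {"content": p, "chunk_order_index": i}
    #           for i, p in enumerate(content.split("\n\n"))
    #       ]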

    def verify_storage_implementation(
        self, storage_type: str, storage_name: str
    ) -> None:
        """Verify that the storage implementation is compatible with the specified storage type

        Args:
            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE, etc.)
            storage_name: Storage implementation name

        Raises:
            ValueError: If the storage implementation is incompatible or missing required methods
        """
        if storage_type not in STORAGE_IMPLEMENTATIONS:
            raise ValueError(f"Unknown storage type: {storage_type}")

        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
        if storage_name not in storage_info["implementations"]:
            raise ValueError(
                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
            )

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check that all environment variables required by the storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )

    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

    def __post_init__(self):
        os.makedirs(self.log_dir, exist_ok=True)
        log_file = os.path.join(self.log_dir, "lightrag.log")
        set_logger(log_file)

        logger.setLevel(self.log_level)
        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
        set_logger(self.log_file_path)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)
@@ -432,22 +282,16 @@ class LightRAG:

        for storage_type, storage_name in storage_configs:
            # Verify storage implementation compatibility
            self.verify_storage_implementation(storage_type, storage_name)
            verify_storage_implementation(storage_type, storage_name)
            # Check environment variables
            # self.check_storage_env_vars(storage_name)

        # Ensure vector_db_storage_cls_kwargs has required fields
        default_vector_db_kwargs = {
            "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
        }
        self.vector_db_storage_cls_kwargs = {
            **default_vector_db_kwargs,
            "cosine_better_than_threshold": self.cosine_better_than_threshold,
            **self.vector_db_storage_cls_kwargs,
        }

        # Life cycle
        self.storages_status = StoragesStatus.NOT_CREATED

        # Show config
        global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -555,7 +399,7 @@ class LightRAG:
            )
        )

        self.storages_status = StoragesStatus.CREATED
        self._storages_status = StoragesStatus.CREATED

        # Initialize storages
        if self.auto_manage_storages_states:
@@ -570,7 +414,7 @@ class LightRAG:

    async def initialize_storages(self):
        """Asynchronously initialize the storages"""
        if self.storages_status == StoragesStatus.CREATED:
        if self._storages_status == StoragesStatus.CREATED:
            tasks = []

            for storage in (
@@ -588,12 +432,12 @@ class LightRAG:

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.INITIALIZED
            self._storages_status = StoragesStatus.INITIALIZED
            logger.debug("Initialized Storages")

    async def finalize_storages(self):
        """Asynchronously finalize the storages"""
        if self.storages_status == StoragesStatus.INITIALIZED:
        if self._storages_status == StoragesStatus.INITIALIZED:
            tasks = []

            for storage in (
@@ -611,7 +455,7 @@ class LightRAG:

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.FINALIZED
            self._storages_status = StoragesStatus.FINALIZED
            logger.debug("Finalized Storages")

    async def get_graph_labels(self):
@@ -687,7 +531,7 @@ class LightRAG:
            return

        update_storage = True
        logger.info(f"[New Docs] inserting {len(new_docs)} docs")
        logger.info(f"Inserting {len(new_docs)} docs")

        inserting_chunks: dict[str, Any] = {}
        for chunk_text in text_chunks:
@@ -780,108 +624,122 @@ class LightRAG:
        4. Update the document status
        """
        # 1. Get all pending, failed, and abnormally terminated processing documents.
        to_process_docs: dict[str, DocProcessingStatus] = {}
        # Run the asynchronous status retrievals in parallel using asyncio.gather
        processing_docs, failed_docs, pending_docs = await asyncio.gather(
            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
            self.doc_status.get_docs_by_status(DocStatus.FAILED),
            self.doc_status.get_docs_by_status(DocStatus.PENDING),
        )

        processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
        to_process_docs: dict[str, DocProcessingStatus] = {}
        to_process_docs.update(processing_docs)
        failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
        to_process_docs.update(failed_docs)
        pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
        to_process_docs.update(pendings_docs)
        to_process_docs.update(pending_docs)

        if not to_process_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        # 2. Split docs into chunks, insert chunks, update doc status
        batch_size = self.addon_params.get("insert_batch_size", 10)
        docs_batches = [
            list(to_process_docs.items())[i : i + batch_size]
            for i in range(0, len(to_process_docs), batch_size)
            list(to_process_docs.items())[i : i + self.max_parallel_insert]
            for i in range(0, len(to_process_docs), self.max_parallel_insert)
        ]

        logger.info(f"Number of batches to process: {len(docs_batches)}.")

        batches: list[Any] = []
        # 3. Iterate over batches
        for batch_idx, docs_batch in enumerate(docs_batches):
            # 4. Iterate over batch
            for doc_id_processing_status in docs_batch:
                doc_id, status_doc = doc_id_processing_status
                # Update status to processing
                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                await self.doc_status.upsert(
                    {
                        doc_status_id: {
                            "status": DocStatus.PROCESSING,
                            "updated_at": datetime.now().isoformat(),
                            "content": status_doc.content,
                            "content_summary": status_doc.content_summary,
                            "content_length": status_doc.content_length,
                            "created_at": status_doc.created_at,
                        }
                    }
                )
                # Generate chunks from document
                chunks: dict[str, Any] = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_id,
                    }
                    for dp in self.chunking_func(
                        status_doc.content,
                        split_by_character,
                        split_by_character_only,
                        self.chunk_overlap_token_size,
                        self.chunk_token_size,
                        self.tiktoken_model_name,
                    )
                }

                # Process document (text chunks and full docs) in parallel
                tasks = [
                    self.chunks_vdb.upsert(chunks),
                    self._process_entity_relation_graph(chunks),
                    self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
                    self.text_chunks.upsert(chunks),
                ]
                try:
                    await asyncio.gather(*tasks)
                    await self.doc_status.upsert(
                        {
                            doc_status_id: {
                                "status": DocStatus.PROCESSED,
                                "chunks_count": len(chunks),
                                "content": status_doc.content,
                                "content_summary": status_doc.content_summary,
                                "content_length": status_doc.content_length,
                                "created_at": status_doc.created_at,
                                "updated_at": datetime.now().isoformat(),
                            }
            async def batch(
                batch_idx: int,
                docs_batch: list[tuple[str, DocProcessingStatus]],
                size_batch: int,
            ) -> None:
                logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
                # 4. Iterate over batch
                for doc_id_processing_status in docs_batch:
                    doc_id, status_doc = doc_id_processing_status
                    # Update status to processing
                    doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                    # Generate chunks from document
                    chunks: dict[str, Any] = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                )
                await self._insert_done()
                        for dp in self.chunking_func(
                            status_doc.content,
                            split_by_character,
                            split_by_character_only,
                            self.chunk_overlap_token_size,
                            self.chunk_token_size,
                            self.tiktoken_model_name,
                        )
                    }
                    # Process document (text chunks and full docs) in parallel
                    tasks = [
                        self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSING,
                                    "updated_at": datetime.now().isoformat(),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                }
                            }
                        ),
                        self.chunks_vdb.upsert(chunks),
                        self._process_entity_relation_graph(chunks),
                        self.full_docs.upsert(
                            {doc_id: {"content": status_doc.content}}
                        ),
                        self.text_chunks.upsert(chunks),
                    ]
                    try:
                        await asyncio.gather(*tasks)
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSED,
                                    "chunks_count": len(chunks),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                    except Exception as e:
                        logger.error(f"Failed to process document {doc_id}: {str(e)}")
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.FAILED,
                                    "error": str(e),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                        continue
                logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

                except Exception as e:
                    logger.error(f"Failed to process document {doc_id}: {str(e)}")
                    await self.doc_status.upsert(
                        {
                            doc_status_id: {
                                "status": DocStatus.FAILED,
                                "error": str(e),
                                "content": status_doc.content,
                                "content_summary": status_doc.content_summary,
                                "content_length": status_doc.content_length,
                                "created_at": status_doc.created_at,
                                "updated_at": datetime.now().isoformat(),
                            }
                        }
                    )
                    continue
            logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
            batches.append(batch(batch_idx, docs_batch, len(docs_batches)))

        await asyncio.gather(*batches)
        await self._insert_done()

    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
        try:
            new_kg = await extract_entities(
            await extract_entities(
                chunk,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
@@ -889,12 +747,6 @@ class LightRAG:
                llm_response_cache=self.llm_response_cache,
                global_config=asdict(self),
            )
            if new_kg is None:
                logger.info("No new entities or relationships extracted.")
            else:
                logger.info("New entities or relationships extracted.")
                self.chunk_entity_relation_graph = new_kg

        except Exception as e:
            logger.error("Failed to extract entities and relationships")
            raise e
@@ -914,6 +766,7 @@ class LightRAG:
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)
        logger.info("All Insert done")

    def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
        loop = always_get_an_event_loop()
@@ -926,11 +779,28 @@ class LightRAG:
        all_chunks_data: dict[str, dict[str, str]] = {}
        chunk_to_source_map: dict[str, str] = {}
        for chunk_data in custom_kg.get("chunks", {}):
            chunk_content = chunk_data["content"]
            chunk_content = chunk_data["content"].strip()
            source_id = chunk_data["source_id"]
            chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
            tokens = len(
                encode_string_by_tiktoken(
                    chunk_content, model_name=self.tiktoken_model_name
                )
            )
            chunk_order_index = (
                0
                if "chunk_order_index" not in chunk_data.keys()
                else chunk_data["chunk_order_index"]
            )
            chunk_id = compute_mdhash_id(chunk_content, prefix="chunk-")

            chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
            chunk_entry = {
                "content": chunk_content,
                "source_id": source_id,
                "tokens": tokens,
                "chunk_order_index": chunk_order_index,
                "full_doc_id": source_id,
                "status": DocStatus.PROCESSED,
            }
            all_chunks_data[chunk_id] = chunk_entry
            chunk_to_source_map[source_id] = chunk_id
            update_storage = True
@@ -1177,7 +1047,6 @@ class LightRAG:
        # ---------------------
        # STEP 1: Keyword Extraction
        # ---------------------
        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
        hl_keywords, ll_keywords = await extract_keywords_only(
            text=query,
            param=param,
@@ -1603,3 +1472,21 @@ class LightRAG:
        result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check that all environment variables required by the storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )
161
lightrag/llm/Readme.md
Normal file
@@ -0,0 +1,161 @@

1. **LlamaIndex** (`llm/llama_index_impl.py`):
   - Provides integration with OpenAI and other providers through LlamaIndex
   - Supports both direct API access and proxy services like LiteLLM
   - Handles embeddings and completions with consistent interfaces
   - See example implementations:
     - [Direct OpenAI Usage](../../examples/lightrag_llamaindex_direct_demo.py)
     - [LiteLLM Proxy Usage](../../examples/lightrag_llamaindex_litellm_demo.py)

<details>
<summary> <b>Using LlamaIndex</b> </summary>

LightRAG supports LlamaIndex for embeddings and completions in two ways: direct OpenAI usage or through a LiteLLM proxy.

### Setup

First, install the required dependencies:

```bash
pip install llama-index-llms-litellm llama-index-embeddings-litellm
```
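
If you intend to use the direct OpenAI path below, you will likely also need the LlamaIndex OpenAI integrations; this is our inference from the example's imports, not a requirement stated in the original:

```bash
pip install llama-index-llms-openai llama-index-embeddings-openai
```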

### Standard OpenAI Usage

```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from lightrag.utils import EmbeddingFunc, logger


# Initialize with direct OpenAI access
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    try:
        # Initialize OpenAI if not in kwargs
        if 'llm_instance' not in kwargs:
            llm_instance = OpenAI(
                model="gpt-4",
                api_key="your-openai-key",
                temperature=0.7,
            )
            kwargs['llm_instance'] = llm_instance

        response = await llama_index_complete_if_cache(
            kwargs['llm_instance'],
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )
        return response
    except Exception as e:
        logger.error(f"LLM request failed: {str(e)}")
        raise


# Initialize LightRAG with OpenAI
rag = LightRAG(
    working_dir="your/path",
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=lambda texts: llama_index_embed(
            texts,
            embed_model=OpenAIEmbedding(
                model="text-embedding-3-large",
                api_key="your-openai-key"
            )
        ),
    ),
)
```
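
After initialization, a minimal usage sketch (our illustration, not from the original README) would be:

```python
from lightrag import QueryParam

rag.insert("Your source text goes here.")
print(rag.query("What does the text say?", param=QueryParam(mode="hybrid")))
```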

### Using LiteLLM Proxy

Routing through a LiteLLM proxy lets you:

1. Use any LLM provider through LiteLLM
2. Leverage LlamaIndex's embedding and completion capabilities
3. Maintain consistent configuration across services

```python
from lightrag import LightRAG
from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
from llama_index.llms.litellm import LiteLLM
from llama_index.embeddings.litellm import LiteLLMEmbedding
from lightrag.utils import EmbeddingFunc, logger


# Initialize with LiteLLM proxy
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    try:
        # Initialize LiteLLM if not in kwargs
        if 'llm_instance' not in kwargs:
            llm_instance = LiteLLM(
                model=f"openai/{settings.LLM_MODEL}",  # Format: "provider/model_name"
                api_base=settings.LITELLM_URL,
                api_key=settings.LITELLM_KEY,
                temperature=0.7,
            )
            kwargs['llm_instance'] = llm_instance

        response = await llama_index_complete_if_cache(
            kwargs['llm_instance'],
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )
        return response
    except Exception as e:
        logger.error(f"LLM request failed: {str(e)}")
        raise


# Initialize LightRAG with LiteLLM
rag = LightRAG(
    working_dir="your/path",
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=lambda texts: llama_index_embed(
            texts,
            embed_model=LiteLLMEmbedding(
                model_name=f"openai/{settings.EMBEDDING_MODEL}",
                api_base=settings.LITELLM_URL,
                api_key=settings.LITELLM_KEY,
            )
        ),
    ),
)
```

### Environment Variables

For OpenAI direct usage:

```bash
OPENAI_API_KEY=your-openai-key
```

For LiteLLM proxy:

```bash
# LiteLLM Configuration
LITELLM_URL=http://litellm:4000
LITELLM_KEY=your-litellm-key

# Model Configuration
LLM_MODEL=gpt-4
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_MAX_TOKEN_SIZE=8192
```
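
The LiteLLM example above references a `settings` object that is never defined in this README; one plausible sketch (an assumption on our part) is a plain holder that mirrors the variables listed here:

```python
import os
from types import SimpleNamespace

# Hypothetical settings holder for the LiteLLM example; it simply reads the
# environment variables documented above, with illustrative defaults.
settings = SimpleNamespace(
    LLM_MODEL=os.getenv("LLM_MODEL", "gpt-4"),
    EMBEDDING_MODEL=os.getenv("EMBEDDING_MODEL", "text-embedding-3-large"),
    LITELLM_URL=os.getenv("LITELLM_URL", "http://litellm:4000"),
    LITELLM_KEY=os.getenv("LITELLM_KEY", "your-litellm-key"),
)
```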

### Key Differences
1. **Direct OpenAI**:
   - Simpler setup
   - Direct API access
   - Requires OpenAI API key

2. **LiteLLM Proxy**:
   - Model provider agnostic
   - Centralized API key management
   - Support for multiple providers
   - Better cost control and monitoring

</details>
208
lightrag/llm/llama_index_impl.py
Normal file
@@ -0,0 +1,208 @@
import pipmaster as pm
from llama_index.core.llms import (
    ChatMessage,
    MessageRole,
    ChatResponse,
)
from typing import List, Optional
from lightrag.utils import logger

# Install required dependencies
if not pm.is_installed("llama-index"):
    pm.install("llama-index")

from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
import numpy as np


def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
    """
    Configure LlamaIndex settings.

    Args:
        settings: LlamaIndex Settings instance. If None, uses default settings.
        **kwargs: Additional settings to override/configure
    """
    if settings is None:
        settings = LlamaIndexSettings()

    # Update settings with any provided kwargs
    for key, value in kwargs.items():
        if hasattr(settings, key):
            setattr(settings, key, value)
        else:
            logger.warning(f"Unknown LlamaIndex setting: {key}")

    # Set as global settings
    LlamaIndexSettings.set_global(settings)
    return settings


def format_chat_messages(messages):
    """Format chat messages into LlamaIndex format."""
    formatted_messages = []

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "system":
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=content)
            )
        elif role == "assistant":
            formatted_messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=content)
            )
        elif role == "user":
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )
        else:
            logger.warning(f"Unknown role {role}, treating as user message")
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )

    return formatted_messages


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: Optional[str] = None,
    history_messages: List[dict] = [],
    **kwargs,
) -> str:
    """Complete the prompt using LlamaIndex."""
    try:
        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
            )

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER
                    if msg["role"] == "user"
                    else MessageRole.ASSISTANT,
                    content=msg["content"],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        # Get LLM instance from kwargs
        if "llm_instance" not in kwargs:
            raise ValueError("llm_instance must be provided in kwargs")
        llm = kwargs["llm_instance"]

        # Get response
        response: ChatResponse = await llm.achat(messages=formatted_messages)

        # In newer versions, the response is in message.content
        content = response.message.content
        return content

    except Exception as e:
        logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
        raise


async def llama_index_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        keyword_extraction: Whether to extract keywords from response
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if history_messages is None:
        history_messages = []

    keyword_extraction = kwargs.pop("keyword_extraction", None)
    result = await llama_index_complete_if_cache(
        kwargs.get("llm_instance"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding = None,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if settings:
        configure_llama_index(settings)

    if embed_model is None:
        raise ValueError("embed_model must be provided")

    # Use _get_text_embeddings for batch processing
    embeddings = embed_model._get_text_embeddings(texts)
    return np.array(embeddings)
@@ -329,7 +329,7 @@ async def extract_entities(
    relationships_vdb: BaseVectorStorage,
    global_config: dict[str, str],
    llm_response_cache: BaseKVStorage | None = None,
) -> BaseGraphStorage | None:
) -> None:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    enable_llm_cache_for_entity_extract: bool = global_config[
@@ -491,11 +491,9 @@ async def extract_entities(
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]

        logger.debug(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            f"Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
        )
        return dict(maybe_nodes), dict(maybe_edges)

@@ -524,16 +522,18 @@ async def extract_entities(
        ]
    )

    if not len(all_entities_data) and not len(all_relationships_data):
        logger.warning(
            "Didn't extract any entities and relationships, maybe your LLM is not working"
        )
        return None
    if not (all_entities_data or all_relationships_data):
        logger.info("Didn't extract any entities and relationships.")
        return

    if not len(all_entities_data):
        logger.warning("Didn't extract any entities")
    if not len(all_relationships_data):
        logger.warning("Didn't extract any relationships")
    if not all_entities_data:
        logger.info("Didn't extract any entities")
    if not all_relationships_data:
        logger.info("Didn't extract any relationships")

    logger.info(
        f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}"
    )

    if entity_vdb is not None:
        data_for_vdb = {
@@ -562,8 +562,6 @@ async def extract_entities(
        }
        await relationships_vdb.upsert(data_for_vdb)

    return knowledge_graph_inst


async def kg_query(
    query: str,
@@ -1328,15 +1326,12 @@ async def _get_edge_data(
        ),
    )

    if not all([n is not None for n in edge_datas]):
        logger.warning("Some edges are missing, maybe the storage is damaged")

    edge_datas = [
        {
            "src_id": k["src_id"],
            "tgt_id": k["tgt_id"],
            "rank": d,
            "created_at": k.get("__created_at__", None),  # fetch time metadata from the KV store
            "created_at": k.get("__created_at__", None),
            **v,
        }
        for k, v, d in zip(results, edge_datas, edge_degree)
@@ -1345,16 +1340,11 @@ async def _get_edge_data(
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    len_edge_datas = len(edge_datas)
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )
    logger.debug(
        f"Truncate relations from {len_edge_datas} to {len(edge_datas)} (max tokens:{query_param.max_token_for_global_context})"
    )

    use_entities, use_text_units = await asyncio.gather(
        _find_most_related_entities_from_relationships(
            edge_datas, query_param, knowledge_graph_inst

@@ -9,15 +9,14 @@ PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]

PROMPTS["entity_extraction"] = """-Goal-
PROMPTS["entity_extraction"] = """---Goal---
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.

-Steps-
---Steps---
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalize the name.
- entity_type: One of the following types: [{entity_types}]
@@ -41,18 +40,17 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
5. When finished, output {completion_delimiter}

######################
-Examples-
---Examples---
######################
{examples}

#############################
-Real Data-
---Real Data---
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
Output:"""

PROMPTS["entity_extraction_examples"] = [
    """Example 1:
@@ -137,7 +135,7 @@ Make sure it is written in third person, and include the entity names so we the
Use {language} as output language.

#######
-Data-
---Data---
Entities: {entity_name}
Description List: {description_list}
#######
@@ -205,12 +203,12 @@ Given the query and conversation history, list both high-level and low-level key
- "low_level_keywords" for specific entities or details

######################
-Examples-
---Examples---
######################
{examples}

#############################
-Real Data-
---Real Data---
######################
Conversation History:
{history}

@@ -713,3 +713,47 @@ def get_conversation_turns(
    )

    return "\n".join(formatted_turns)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop
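
# Illustrative call site (a sketch, not part of this diff): drive an async API
# from synchronous code without worrying about loop state, e.g.
#   loop = always_get_an_event_loop()
#   loop.run_until_complete(some_coroutine())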


def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class
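
# Hypothetical usage (names are illustrative, not from the source): the heavy
# import is deferred until the first construction, e.g.
#   SomeStorage = lazy_external_import(".kg.some_impl", "SomeStorage")
#   store = SomeStorage(namespace="chunks")  # module import happens here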

@@ -1,7 +1,7 @@
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import always_get_an_event_loop


def extract_queries(file_path):
@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
    return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def run_queries_and_save_to_json(
    queries, rag_instance, query_param, output_file, error_file
):

@@ -1,10 +1,9 @@
import os
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
import numpy as np

@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
    return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def run_queries_and_save_to_json(
    queries, rag_instance, query_param, output_file, error_file
):