Merge branch 'main' into select-datastore-in-api-server

This commit is contained in:
yangdx
2025-02-12 09:49:18 +08:00
11 changed files with 320 additions and 59 deletions

View File

@@ -85,7 +85,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
```python
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
#########
# Uncomment the two lines below when running in a Jupyter notebook to handle the async nature of rag.insert()
@@ -95,12 +95,12 @@ from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
embedding_func=openai_embed,
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
)
@@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and
```python
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
"""Specifies the retrieval mode:
- "local": Focuses on context-dependent information.
- "global": Utilizes global knowledge.
- "hybrid": Combines local and global retrieval methods.
- "naive": Performs a basic search without advanced techniques.
- "mix": Integrates knowledge graph and vector retrieval.
"""
only_need_context: bool = False
"""If True, only returns the retrieved context without generating a response."""
response_type: str = "Multiple Paragraphs"
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
top_k: int = 60
# Number of tokens for the original chunks.
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_local_context: int = 4000
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
...
```
> The default value of `top_k` can be changed via the environment variable `TOP_K`.
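
For orientation, a minimal sketch (not part of this diff) of how these parameters might be combined. The field names come from the `QueryParam` definition above; the environment override works because the `top_k` default is read when the class is defined (see `lightrag/base.py` further below), so `TOP_K` must be set before `lightrag` is imported.

```python
import os

# TOP_K is read when QueryParam is defined, so set it before importing lightrag.
os.environ["TOP_K"] = "100"

from lightrag import QueryParam

# An explicit argument still overrides both the built-in default and TOP_K.
param = QueryParam(
    mode="mix",                     # knowledge-graph + vector retrieval
    top_k=80,                       # per-query override
    response_type="Bullet Points",  # desired output format
    only_need_context=False,        # generate a full response, not just context
)
```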

View File

@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
LLM_MODEL,
prompt,
model=LLM_MODEL,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=BASE_URL,
api_key=API_KEY,
**kwargs,
)
@@ -49,8 +55,10 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
texts=texts,
model=EMBEDDING_MODEL,
base_url=BASE_URL,
api_key=API_KEY,
)
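
As a hedged sketch (not part of the commit), the demo above can be pointed at any OpenAI-compatible endpoint by setting the two new variables before its top-level code runs; the URL and key below are placeholders.

```python
import os

# Placeholder endpoint and key; any OpenAI-compatible server works in principle.
os.environ["BASE_URL"] = "http://localhost:8080/v1"
os.environ["API_KEY"] = "sk-your-key"

# The demo reads these at module level via os.environ.get(...), so they must be
# in place before the script starts (e.g. export them in the shell instead).
```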

View File

@@ -0,0 +1,101 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
model=LLM_MODEL,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=BASE_URL,
api_key=API_KEY,
**kwargs,
)
# Embedding function
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts=texts,
model=EMBEDDING_MODEL,
base_url=BASE_URL,
api_key=API_KEY,
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"{embedding_dim=}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
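
The new example exercises the naive, local, global, and hybrid modes; as a hedged extension, the same `QueryParam` also accepts the `"mix"` mode documented elsewhere in this commit.

```python
# Hedged extension of the example above: "mix" integrates knowledge-graph and
# vector retrieval (mode name taken from the QueryParam changes in this commit).
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="mix"))
)
```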

View File

@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
WORKING_DIR = "./dickens"
@@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
embedding_func=openai_embed,
llm_model_func=gpt_4o_mini_complete,
# llm_model_func=gpt_4o_complete
)

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.1.5"
__version__ = "1.1.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -27,30 +27,54 @@ T = TypeVar("T")
@dataclass
class QueryParam:
"""Configuration parameters for query execution in LightRAG."""
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
"""Specifies the retrieval mode:
- "local": Focuses on context-dependent information.
- "global": Utilizes global knowledge.
- "hybrid": Combines local and global retrieval methods.
- "naive": Performs a basic search without advanced techniques.
- "mix": Integrates knowledge graph and vector retrieval.
"""
only_need_context: bool = False
"""If True, only returns the retrieved context without generating a response."""
only_need_prompt: bool = False
"""If True, only returns the generated prompt without producing a response."""
response_type: str = "Multiple Paragraphs"
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
stream: bool = False
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
"""If True, enables streaming output for real-time responses."""
top_k: int = int(os.getenv("TOP_K", "60"))
# Number of document chunks to retrieve.
# top_n: int = 10
# Number of tokens for the original chunks.
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_local_context: int = 4000
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
hl_keywords: list[str] = field(default_factory=list)
"""List of high-level keywords to prioritize in retrieval."""
ll_keywords: list[str] = field(default_factory=list)
# Conversation history support
conversation_history: list[dict[str, str]] = field(
default_factory=list
) # Format: [{"role": "user/assistant", "content": "message"}]
history_turns: int = (
3 # Number of complete conversation turns (user-assistant pairs) to consider
)
"""List of low-level keywords to refine retrieval focus."""
conversation_history: list[dict[str, Any]] = field(default_factory=list)
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
history_turns: int = 3
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
@dataclass
@@ -202,3 +226,7 @@ class DocStatusStorage(BaseKVStorage):
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
"""Updates the status of a document. By default, it calls upsert."""
await self.upsert(data)
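
A minimal sketch (not part of the commit) of how a concrete backend might override the new default instead of upserting full records; `MyDocStatusStorage` and `_write_status_fields` are hypothetical names.

```python
from typing import Any

# DocStatusStorage is the base class defined above; the subclass and helper
# below are hypothetical.
class MyDocStatusStorage(DocStatusStorage):
    """Sketch of a backend that persists only status-related fields."""

    async def update_doc_status(self, data: dict[str, Any]) -> None:
        for doc_id, fields in data.items():
            # Touch only the tracking fields instead of replacing the whole record.
            await self._write_status_fields(doc_id, fields)  # hypothetical helper
```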

View File

@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
if v["status"] == DocStatus.PENDING
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSED
}
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING
}
async def index_done_callback(self):
"""Save data to file after indexing"""
write_json(self._data, self._file_name)

View File

@@ -530,6 +530,32 @@ class PGDocStatusStorage(DocStatusStorage):
)
return data
async def update_doc_status(self, data: dict[str, dict]) -> None:
"""
Updates only the document status, chunk count, and updated timestamp.
This method ensures that only relevant fields are updated instead of overwriting
the entire document record. The `updated_at` column is set to the database's
current timestamp on every update.
"""
sql = """
UPDATE LIGHTRAG_DOC_STATUS
SET status = $3,
chunks_count = $4,
updated_at = CURRENT_TIMESTAMP
WHERE workspace = $1 AND id = $2
"""
for k, v in data.items():
_data = {
"workspace": self.db.workspace,
"id": k,
"status": v["status"].value, # Convert Enum to string
"chunks_count": v.get(
"chunks_count", -1
), # Default to -1 if not provided
}
await self.db.execute(sql, _data)
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""

View File

@@ -211,38 +211,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@dataclass
class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation."""
working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
)
# Default not to use embedding cache
embedding_cache_config: dict = field(
"""Directory where cache and temporary files are stored."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
kv_storage: str = field(default="JsonKVStorage")
"""Storage backend for key-value data."""
vector_storage: str = field(default="NanoVectorDBStorage")
"""Storage backend for vector embeddings."""
graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs."""
# logging
# Logging
current_log_level = logger.level
log_level: str = field(default=current_log_level)
log_level: int = field(default=current_log_level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_dir: str = field(default=os.getcwd())
"""Directory where logs are stored. Defaults to the current working directory."""
# text chunking
# Text chunking
chunk_token_size: int = 1200
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = 100
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
"""Model name used for tokenization when chunking text."""
# entity extraction
# Entity extraction
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
"""Maximum number of entity extraction attempts for ambiguous content."""
# node embedding
entity_summary_to_max_tokens: int = 500
"""Maximum number of tokens used for summarizing extracted entities."""
# Node embedding
node_embedding_algorithm: str = "node2vec"
node2vec_params: dict = field(
"""Algorithm used for node embedding in knowledge graphs."""
node2vec_params: dict[str, int] = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
@@ -252,26 +279,56 @@ class LightRAG:
"random_seed": 3,
}
)
"""Configuration for the node2vec embedding algorithm:
- dimensions: Number of dimensions for embeddings.
- num_walks: Number of random walks per node.
- walk_length: Number of steps per random walk.
- window_size: Context window size for training.
- iterations: Number of iterations for training.
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
"""Function for computing text embeddings. Must be set before use."""
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = None  # This must be set (we want to keep the LLM separate from the core, so no default initialization)
embedding_batch_num: int = 32
"""Batch size for embedding computations."""
embedding_func_max_async: int = 16
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
"""Name of the LLM model used for generating responses."""
# LLM
llm_model_func: callable = None  # This must be set (we want to keep the LLM separate from the core, so no default initialization)
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
llm_model_kwargs: dict = field(default_factory=dict)
"""Maximum number of tokens allowed per LLM response."""
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
"""Maximum number of concurrent LLM calls."""
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Storage
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
# storage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = True
# Sometimes the LLM fails at extracting entities; this flag lets processing continue without extra LLM cost
"""Enables caching for LLM responses to avoid redundant computations."""
enable_llm_cache_for_entity_extract: bool = True
"""If True, enables caching for entity extraction steps to reduce LLM costs."""
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
# extension
addon_params: dict[str, Any] = field(default_factory=dict)
@@ -279,8 +336,8 @@ class LightRAG:
convert_response_to_json
)
# Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Custom Chunking Function
chunking_func: Callable[
@@ -799,7 +856,7 @@ class LightRAG:
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
logger.info("No new unique documents were found.")
return
# 4. Store status document
@@ -816,15 +873,16 @@ class LightRAG:
each chunk for entity and relation extraction, and updating the
document status.
1. Get all pending and failed documents
1. Get all pending, failed, and abnormally terminated processing documents.
2. Split document content into chunks
3. Process each chunk for entity and relation extraction
4. Update the document status
"""
# 1. get all pending and failed documents
# 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {}
# Fetch failed documents
processing_docs = await self.doc_status.get_processing_docs()
to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs()
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs()
@@ -855,6 +913,7 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
@@ -886,11 +945,15 @@ class LightRAG:
]
try:
await asyncio.gather(*tasks)
await self.doc_status.upsert(
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
@@ -899,11 +962,15 @@ class LightRAG:
except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.upsert(
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
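
To tie the newly documented `LightRAG` fields together, a hedged configuration sketch; the stub callables are placeholders standing in for real LLM and embedding functions, not part of the commit.

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

# Stub callables so the sketch stands on its own; replace with real ones.
async def stub_llm(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return "stub response"

async def stub_embed(texts: list[str]) -> np.ndarray:
    return np.zeros((len(texts), 8), dtype=np.float32)

rag = LightRAG(
    working_dir="./my_rag_cache",
    doc_status_storage="JsonDocStatusStorage",  # default backend, spelled out
    embedding_cache_config={
        "enabled": True,                        # reuse cached embeddings
        "similarity_threshold": 0.95,
        "use_llm_check": False,
    },
    embedding_func=EmbeddingFunc(
        embedding_dim=8, max_token_size=8192, func=stub_embed
    ),
    llm_model_func=stub_llm,
)
```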

View File

@@ -103,17 +103,19 @@ async def openai_complete_if_cache(
) -> str:
if history_messages is None:
history_messages = []
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
AsyncOpenAI(default_headers=default_headers)
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
else AsyncOpenAI(
base_url=base_url, default_headers=default_headers, api_key=api_key
)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
@@ -294,17 +296,19 @@ async def openai_embed(
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
AsyncOpenAI(default_headers=default_headers)
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
else AsyncOpenAI(
base_url=base_url, default_headers=default_headers, api_key=api_key
)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
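
For illustration, a hedged sketch of the two ways a key can be supplied after this change: explicitly per call, or via `OPENAI_API_KEY` as a fallback (nothing is written back into the environment anymore). The key values are placeholders.

```python
import asyncio
import os

from lightrag.llm.openai import openai_embed

# Fallback path: with no api_key argument, the function reads OPENAI_API_KEY itself.
os.environ.setdefault("OPENAI_API_KEY", "sk-env-key")  # placeholder

async def main() -> None:
    vectors = await openai_embed(
        texts=["hello world"],
        model="text-embedding-3-small",
        api_key="sk-explicit-key",  # explicit path: passed straight to the client
    )
    print(vectors.shape)

asyncio.run(main())
```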

View File

@@ -1504,7 +1504,7 @@ async def naive_query(
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "default", cache_type="query"
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response