Merge branch 'main' into yangdx

This commit is contained in:
yangdx
2025-01-16 20:20:09 +08:00
16 changed files with 1084 additions and 194 deletions

2
.gitignore vendored
View File

@@ -21,4 +21,4 @@ rag_storage
venv/
examples/input/
examples/output/
test_results.json
.DS_Store

View File

@@ -330,6 +330,26 @@ rag = LightRAG(
with open("./newText.txt") as f:
rag.insert(f.read())
```
### Separate Keyword Extraction
We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords.
##### How It Works?
The function operates by dividing the input into two parts:
- `User Query`
- `Prompt`
It then performs keyword extraction exclusively on the `user query`. This separation ensures that the extraction process is focused and relevant, unaffected by any additional language in the `prompt`. It also allows the `prompt` to serve purely for response formatting, maintaining the intent and clarity of the user's original question.
##### Usage Example
This `example` shows how to tailor the function for educational content, focusing on detailed explanations for older students.
```python
rag.query_with_separate_keyword_extraction(
query="Explain the law of gravity",
prompt="Provide a detailed explanation suitable for high school students studying physics.",
param=QueryParam(mode="hybrid")
)
```
### Using Neo4J for Storage
@@ -361,6 +381,7 @@ see test_neo4j.py for a working example.
### Using PostgreSQL for Storage
For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
* If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
* Create index for AGE example: (Change below `dickens` to your graph name if necessary)
```

View File

@@ -0,0 +1,97 @@
"""
Sometimes you need to switch a storage solution, but you want to save LLM token and time.
This handy script helps you to copy the LLM caches from one storage solution to another.
(Not all the storage impl are supported)
"""
import asyncio
import logging
import os
from dotenv import load_dotenv
from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage
from lightrag.storage import JsonKVStorage
load_dotenv()
ROOT_DIR = os.environ.get("ROOT_DIR")
WORKING_DIR = f"{ROOT_DIR}/dickens"
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# AGE
os.environ["AGE_GRAPH_NAME"] = "chinese"
postgres_db = PostgreSQLDB(
config={
"host": "localhost",
"port": 15432,
"user": "rag",
"password": "rag",
"database": "r2",
}
)
async def copy_from_postgres_to_json():
await postgres_db.initdb()
from_llm_response_cache = PGKVStorage(
namespace="llm_response_cache",
global_config={"embedding_batch_num": 6},
embedding_func=None,
db=postgres_db,
)
to_llm_response_cache = JsonKVStorage(
namespace="llm_response_cache",
global_config={"working_dir": WORKING_DIR},
embedding_func=None,
)
kv = {}
for c_id in await from_llm_response_cache.all_keys():
print(f"Copying {c_id}")
workspace = c_id["workspace"]
mode = c_id["mode"]
_id = c_id["id"]
postgres_db.workspace = workspace
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
if mode not in kv:
kv[mode] = {}
kv[mode][_id] = obj[_id]
print(f"Object {obj}")
await to_llm_response_cache.upsert(kv)
await to_llm_response_cache.index_done_callback()
print("Mission accomplished!")
async def copy_from_json_to_postgres():
await postgres_db.initdb()
from_llm_response_cache = JsonKVStorage(
namespace="llm_response_cache",
global_config={"working_dir": WORKING_DIR},
embedding_func=None,
)
to_llm_response_cache = PGKVStorage(
namespace="llm_response_cache",
global_config={"embedding_batch_num": 6},
embedding_func=None,
db=postgres_db,
)
for mode in await from_llm_response_cache.all_keys():
print(f"Copying {mode}")
caches = await from_llm_response_cache.get_by_id(mode)
for k, v in caches.items():
item = {mode: {k: v}}
print(f"\tCopying {item}")
await to_llm_response_cache.upsert(item)
if __name__ == "__main__":
asyncio.run(copy_from_json_to_postgres())

View File

@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
CHATMODEL = "cohere.command-r-plus"
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
CHUNK_TOKEN_SIZE = 1024
MAX_TOKENS = 4000
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -86,30 +87,46 @@ async def main():
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
rag = LightRAG(
enable_llm_cache=False,
# log_level="DEBUG",
working_dir=WORKING_DIR,
chunk_token_size=512,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract=True,
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
chunk_token_size=CHUNK_TOKEN_SIZE,
llm_model_max_token_size=MAX_TOKENS,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
max_token_size=500,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params={
"example_number": 1,
"language": "Simplfied Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size": 2,
},
)
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
rag.graph_storage_cls.db = oracle_db
rag.key_string_value_json_storage_cls.db = oracle_db
rag.vector_db_storage_cls.db = oracle_db
# add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
rag.set_storage_client(db_client=oracle_db)
# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
all_text = f.read()
texts = [x for x in all_text.split("\n") if x]
# New mode use pipeline
await rag.apipeline_process_documents(texts)
await rag.apipeline_process_chunks()
await rag.apipeline_process_extract_graph()
# Old method use ainsert
# await rag.ainsert(texts)
# Perform search in different modes
modes = ["naive", "local", "global", "hybrid"]

View File

@@ -0,0 +1,116 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
import logging
from openai import AzureOpenAI
logging.basicConfig(level=logging.INFO)
load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
client = AzureOpenAI(
api_key=AZURE_OPENAI_API_KEY,
api_version=AZURE_OPENAI_API_VERSION,
azure_endpoint=AZURE_OPENAI_ENDPOINT,
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
chat_completion = client.chat.completions.create(
model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name".
messages=messages,
temperature=kwargs.get("temperature", 0),
top_p=kwargs.get("top_p", 1),
n=kwargs.get("n", 1),
)
return chat_completion.choices[0].message.content
async def embedding_func(texts: list[str]) -> np.ndarray:
client = AzureOpenAI(
api_key=AZURE_OPENAI_API_KEY,
api_version=AZURE_EMBEDDING_API_VERSION,
azure_endpoint=AZURE_OPENAI_ENDPOINT,
)
embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
embeddings = [item.embedding for item in embedding.data]
return np.array(embeddings)
async def test_funcs():
result = await llm_model_func("How are you?")
print("Resposta do llm_model_func: ", result)
result = await embedding_func(["How are you?"])
print("Resultado do embedding_func: ", result.shape)
print("Dimensão da embedding: ", result.shape[1])
asyncio.run(test_funcs())
embedding_dimension = 3072
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
book1 = open("./book_1.txt", encoding="utf-8")
book2 = open("./book_2.txt", encoding="utf-8")
rag.insert([book1.read(), book2.read()])
# Example function demonstrating the new query_with_separate_keyword_extraction usage
async def run_example():
query = "What are the top themes in this story?"
prompt = "Please simplify the response for a young audience."
# Using the new method to ensure the keyword extraction is only applied to the query
response = rag.query_with_separate_keyword_extraction(
query=query,
prompt=prompt,
param=QueryParam(mode="hybrid"), # Adjust QueryParam mode as necessary
)
print("Extracted Response:", response)
# Run the example asynchronously
if __name__ == "__main__":
asyncio.run(run_example())

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.1.1"
__version__ = "1.1.2"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -31,6 +31,8 @@ class QueryParam:
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
max_token_for_local_context: int = 4000
hl_keywords: list[str] = field(default_factory=list)
ll_keywords: list[str] = field(default_factory=list)
@dataclass

View File

@@ -153,8 +153,6 @@ class OracleDB:
if data is None:
await cursor.execute(sql)
else:
# print(data)
# print(sql)
await cursor.execute(sql, data)
await connection.commit()
except Exception as e:
@@ -167,35 +165,64 @@ class OracleDB:
@dataclass
class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db
db: OracleDB = None
meta_fields = None
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
"""get doc_full data based on id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
# print("get_by_id:"+SQL)
res = await self.db.query(SQL, params)
if "llm_response_cache" == self.namespace:
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
else:
res = await self.db.query(SQL, params)
if res:
data = res # {"data":res}
# print (data)
return data
return res
else:
return None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
if "llm_response_cache" == self.namespace:
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
"""get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
# print("get_by_ids:"+SQL)
# print(params)
res = await self.db.query(SQL, params, multirows=True)
if "llm_response_cache" == self.namespace:
modes = set()
dict_res: dict[str, dict] = {}
for row in res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in res:
dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()]
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -203,38 +230,43 @@ class OracleKVStorage(BaseKVStorage):
else:
return None
async def get_by_status_and_ids(
self, status: str, ids: list[str]
) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
else:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, multirows=True)
if res:
return res
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
)
params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
res = await self.db.query(SQL, params, multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
return data
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
return set(keys)
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
# print(self._data)
# values = []
if self.namespace == "text_chunks":
list_data = [
{
"__id__": k,
"id": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
@@ -250,35 +282,50 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# print(list_data)
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {
"check_id": item["__id__"],
"id": item["__id__"],
_data = {
"id": item["id"],
"content": item["content"],
"workspace": self.db.workspace,
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": item["__vector__"],
"status": item["status"],
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, _data)
if self.namespace == "full_docs":
for k, v in self._data.items():
for k, v in data.items():
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"]
data = {
"check_id": k,
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
return left_data
await self.db.execute(merge_sql, _data)
if self.namespace == "llm_response_cache":
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"cache_mode": mode,
}
await self.db.execute(upsert_sql, _data)
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
@@ -287,6 +334,8 @@ class OracleKVStorage(BaseKVStorage):
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
db: OracleDB = None
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
@@ -328,7 +377,7 @@ class OracleGraphStorage(BaseGraphStorage):
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -362,7 +411,6 @@ class OracleGraphStorage(BaseGraphStorage):
"content": content,
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data)
@@ -564,20 +612,26 @@ N_T = {
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
id varchar(256)PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
doc_name varchar(1024),
content CLOB,
meta JSON,
content_summary varchar(1024),
content_length NUMBER,
status varchar(256),
chunks_count NUMBER,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL
updatetime TIMESTAMP DEFAULT NULL,
error varchar(4096)
)"""
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id varchar(256) PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
full_doc_id varchar(256),
status varchar(256),
chunk_order_index NUMBER,
tokens NUMBER,
content CLOB,
@@ -619,9 +673,15 @@ TABLES = {
"LIGHTRAG_LLM_CACHE": {
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
id varchar(256) PRIMARY KEY,
send clob,
return clob,
model varchar(1024),
workspace varchar(1024),
cache_mode varchar(256),
model_name varchar(256),
original_prompt clob,
return_value clob,
embedding CLOB,
embedding_shape NUMBER,
embedding_min NUMBER,
embedding_max NUMBER,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL
)"""
@@ -646,23 +706,44 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)
""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL
ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN INSERT
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
ON (a.id = :id)
WHEN NOT MATCHED THEN
INSERT (workspace,id,original_prompt,return_value,cache_mode)
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
WHEN MATCHED THEN UPDATE
SET original_prompt = :original_prompt,
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
@@ -714,16 +795,22 @@ SQL_TEMPLATES = {
COLUMNS (a.name as source_name,b.name as target_name))""",
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
USING DUAL
ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
ON (a.workspace=:workspace and a.name=:name)
WHEN NOT MATCHED THEN
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
USING DUAL
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"get_all_nodes": """WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids

View File

@@ -231,6 +231,16 @@ class PGKVStorage(BaseKVStorage):
else:
return None
async def all_keys(self) -> list[dict]:
if "llm_response_cache" == self.namespace:
sql = "select workspace,mode,id from lightrag_llm_cache"
res = await self.db.query(sql, multirows=True)
return res
else:
logger.error(
f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
)
async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
@@ -412,7 +422,10 @@ class PGDocStatusStorage(DocStatusStorage):
async def filter_keys(self, data: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
keys = ",".join([f"'{_id}'" for _id in data])
sql = (
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})"
)
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None:

View File

@@ -17,6 +17,8 @@ from .operate import (
kg_query,
naive_query,
mix_kg_vector_query,
extract_keywords_only,
kg_query_with_keywords,
)
from .utils import (
@@ -26,6 +28,7 @@ from .utils import (
convert_response_to_json,
logger,
set_logger,
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -36,21 +39,30 @@ from .base import (
DocStatus,
)
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
JsonDocStatusStorage,
)
from .prompt import GRAPH_FIELD_SEP
# future KG integrations
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
STORAGES = {
"JsonKVStorage": ".storage",
"NanoVectorDBStorage": ".storage",
"NetworkXStorage": ".storage",
"JsonDocStatusStorage": ".storage",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
}
def lazy_external_import(module_name: str, class_name: str):
@@ -66,34 +78,13 @@ def lazy_external_import(module_name: str, class_name: str):
def import_class(*args, **kwargs):
import importlib
# Import the module using importlib
module = importlib.import_module(module_name, package=package)
# Get the class from the module and instantiate it
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
@@ -197,34 +188,51 @@ class LightRAG:
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class()[self.kv_storage]
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
self.vector_storage
]
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.graph_storage
]
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
# show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls, global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None,
)
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
####
@@ -232,17 +240,14 @@ class LightRAG:
####
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
####
@@ -251,72 +256,69 @@ class LightRAG:
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls(
namespace="doc_status",
global_config=asdict(self),
global_config=global_config,
embedding_func=None,
)
def _get_storage_class(self) -> dict:
return {
# kv storage
"JsonKVStorage": JsonKVStorage,
"OracleKVStorage": OracleKVStorage,
"MongoKVStorage": MongoKVStorage,
"TiDBKVStorage": TiDBKVStorage,
# vector storage
"NanoVectorDBStorage": NanoVectorDBStorage,
"OracleVectorDBStorage": OracleVectorDBStorage,
"MilvusVectorDBStorge": MilvusVectorDBStorge,
"ChromaVectorDBStorage": ChromaVectorDBStorage,
"TiDBVectorDBStorage": TiDBVectorDBStorage,
# graph storage
"NetworkXStorage": NetworkXStorage,
"Neo4JStorage": Neo4JStorage,
"OracleGraphStorage": OracleGraphStorage,
"AGEStorage": AGEStorage,
"PGGraphStorage": PGGraphStorage,
"PGKVStorage": PGKVStorage,
"PGDocStatusStorage": PGDocStatusStorage,
"PGVectorStorage": PGVectorStorage,
"TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage
"JsonDocStatusStorage": JsonDocStatusStorage,
}
def _get_storage_class(self, storage_name: str) -> dict:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self, db_client):
# Now only tested on Oracle Database
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
def insert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -538,6 +540,195 @@ class LightRAG:
if update_storage:
await self._insert_done()
async def apipeline_process_documents(self, string_or_strings):
"""Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
Args:
string_or_strings: Single document string or list of document strings
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
logger.info(
f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
)
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": None,
}
for content in unique_contents
}
# 3. Filter out already processed documents
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(
f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
)
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return None
# 4. Store original document
for doc_id, doc in new_docs.items():
await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
await self.full_docs.change_status(doc_id, DocStatus.PENDING)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
_todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
if not _todo_doc_keys:
logger.info("All documents have been processed or are duplicates")
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = {
doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
# 2. split docs into chunks, insert chunks, update doc status
chunk_cnt = 0
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(),
desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}",
):
try:
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"status": DocStatus.PENDING,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
chunk_cnt += len(chunks)
await self.text_chunks.upsert(chunks)
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Update doc status
await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
except Exception as e:
# Mark as failed if any step fails
await self.full_docs.change_status(doc_id, DocStatus.FAILED)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
semaphore = asyncio.Semaphore(
batch_size
) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async with semaphore:
chunks = {
i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
}
# Extract and store entities and relationships
try:
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.info("No entities or relationships extracted!")
# Update status to processed
await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
except Exception as e:
logger.error("Failed to extract entities and relationships")
# Mark as failed if any step fails
await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
raise e
with tqdm_async(
total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0,
) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
for future in asyncio.as_completed(tasks):
await future
progress.update(1)
progress.set_postfix(
{
"LLM call": statistic_data["llm_call"],
"LLM cache": statistic_data["llm_cache"],
}
)
# Ensure all indexes are updated after each document
await self._insert_done()
async def _insert_done(self):
tasks = []
for storage_inst in [
@@ -753,6 +944,114 @@ class LightRAG:
await self._query_done()
return response
def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
2. Then run the standard aquery() flow with the final prompt (formatted_question).
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aquery_with_separate_keyword_extraction(query, prompt, param)
)
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
"""
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
# We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
param.hl_keywords = (hl_keywords,)
param.ll_keywords = (ll_keywords,)
# ---------------------
# STEP 2: Final Query Logic
# ---------------------
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query_with_keywords(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
elif param.mode == "naive":
response = await naive_query(
formatted_question,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:

View File

@@ -20,6 +20,7 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data,
)
from .base import (
BaseGraphStorage,
@@ -96,6 +97,10 @@ async def _handle_entity_relation_summary(
description: str,
global_config: dict,
) -> str:
"""Handle entity relation summary
For each entity or relation, input is the combined description of already existing description and new description.
If too long, use LLM to summarize.
"""
use_llm_func: callable = global_config["llm_model_func"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
@@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert(
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = []
already_source_ids = []
already_description = []
@@ -356,7 +362,7 @@ async def extract_entities(
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages)
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
@@ -368,8 +374,10 @@ async def extract_entities(
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return
statistic_data["llm_call"] += 1
if history_messages:
res: str = await use_llm_func(
input_text, history_messages=history_messages
@@ -388,6 +396,11 @@ async def extract_entities(
return await use_llm_func(input_text)
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
""" "Prpocess a single chunk
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
"""
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
@@ -451,10 +464,8 @@ async def extract_entities(
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
logger.debug(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
@@ -462,8 +473,10 @@ async def extract_entities(
for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Extracting entities from chunks",
desc="Level 2 - Extracting entities and relationships",
unit="chunk",
position=1,
leave=False,
):
results.append(await result)
@@ -474,7 +487,7 @@ async def extract_entities(
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logger.info("Inserting entities into storage...")
logger.debug("Inserting entities into storage...")
all_entities_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -484,12 +497,14 @@ async def extract_entities(
]
),
total=len(maybe_nodes),
desc="Inserting entities",
desc="Level 3 - Inserting entities",
unit="entity",
position=2,
leave=False,
):
all_entities_data.append(await result)
logger.info("Inserting relationships into storage...")
logger.debug("Inserting relationships into storage...")
all_relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -501,8 +516,10 @@ async def extract_entities(
]
),
total=len(maybe_edges),
desc="Inserting relationships",
desc="Level 3 - Inserting relationships",
unit="relationship",
position=3,
leave=False,
):
all_relationships_data.append(await result)
@@ -681,6 +698,219 @@ async def kg_query(
return response
async def kg_query_with_keywords(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""
# ---------------------------
# 0) Handle potential cache
# ---------------------------
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
# ---------------------------
# 1) RETRIEVE KEYWORDS FROM query_param
# ---------------------------
# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []
# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
"No keywords found in query_param. Could default to global mode or fail."
)
return PROMPTS["fail_response"]
if not ll_keywords and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty, switching to global mode.")
query_param.mode = "global"
if not hl_keywords and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"
# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)
# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
keywords = [ll_keywords_str, hl_keywords_str]
logger.info("Using %s mode for query processing", query_param.mode)
# ---------------------------
# 2) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if not context:
return PROMPTS["fail_response"]
# If only context is needed, return it
if query_param.only_need_context:
return context
# ---------------------------
# 3) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
# Now call the LLM with the final system prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
# Clean up the response
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# ---------------------------
# 4) SAVE TO CACHE
# ---------------------------
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
),
)
return response
async def extract_keywords_only(
text: str,
param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
This method does NOT build the final RAG context or provide a final answer.
It ONLY extracts keywords (hl_keywords, ll_keywords).
"""
# 1. Handle cache if needed
args_hash = compute_args_hash(param.mode, text)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode
)
if cached_response is not None:
# parse the cached_response if its JSON containing keywords
# or simply return (hl_keywords, ll_keywords) from cached
# Assuming cached_response is in the same JSON structure:
match = re.search(r"\{.*\}", cached_response, re.DOTALL)
if match:
keywords_data = json.loads(match.group(0))
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
return hl_keywords, ll_keywords
return [], []
# 2. Build the examples
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
# 3. Build the keyword-extraction prompt
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
# 4. Call the LLM for keyword extraction
use_model_func = global_config["llm_model_func"]
result = await use_model_func(kw_prompt, keyword_extraction=True)
# 5. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
return [], []
try:
keywords_data = json.loads(match.group(0))
except json.JSONDecodeError as e:
logger.error(f"JSON parsing error: {e}")
return [], []
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# 6. Cache the result if needed
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
),
)
return hl_keywords, ll_keywords
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,

View File

@@ -30,13 +30,18 @@ class UnlimitedSemaphore:
ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
@@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return None, None, None, None
# For naive mode, only use simple cache matching
if mode == "naive":
# if mode == "naive":
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
@@ -473,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
embedding_model_func = hashing_kv.global_config[
"embedding_func"
].func # ["func"]
llm_model_func = hashing_kv.global_config.get("llm_model_func")
current_embedding = await embedding_model_func([prompt])