support TiDB: add TiDBKVStorage, TiDBVectorDBStorage
This commit is contained in:
127
examples/lightrag_tidb_demo.py
Normal file
127
examples/lightrag_tidb_demo.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.kg.tidb_impl import TiDB
|
||||||
|
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
|
# We use SiliconCloud API to call LLM on Oracle Cloud
|
||||||
|
# More docs here https://docs.siliconflow.cn/introduction
|
||||||
|
BASE_URL = "https://api.siliconflow.cn/v1/"
|
||||||
|
APIKEY = ""
|
||||||
|
CHATMODEL = ""
|
||||||
|
EMBEDMODEL = ""
|
||||||
|
|
||||||
|
TIDB_HOST = ""
|
||||||
|
TIDB_PORT = ""
|
||||||
|
TIDB_USER = ""
|
||||||
|
TIDB_PASSWORD = ""
|
||||||
|
TIDB_DATABASE = ""
|
||||||
|
|
||||||
|
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
CHATMODEL,
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=APIKEY,
|
||||||
|
base_url=BASE_URL,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await siliconcloud_embedding(
|
||||||
|
texts,
|
||||||
|
# model=EMBEDMODEL,
|
||||||
|
api_key=APIKEY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding_dim():
|
||||||
|
test_text = ["This is a test sentence."]
|
||||||
|
embedding = await embedding_func(test_text)
|
||||||
|
embedding_dim = embedding.shape[1]
|
||||||
|
return embedding_dim
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
try:
|
||||||
|
# Detect embedding dimension
|
||||||
|
embedding_dimension = await get_embedding_dim()
|
||||||
|
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||||
|
|
||||||
|
# Create TiDB DB connection
|
||||||
|
tidb = TiDB(
|
||||||
|
config={
|
||||||
|
"host": TIDB_HOST,
|
||||||
|
"port": TIDB_PORT,
|
||||||
|
"user": TIDB_USER,
|
||||||
|
"password": TIDB_PASSWORD,
|
||||||
|
"database": TIDB_DATABASE,
|
||||||
|
"workspace": "company", # specify which docs you want to store and query
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if TiDB DB tables exist, if not, tables will be created
|
||||||
|
await tidb.check_tables()
|
||||||
|
|
||||||
|
# Initialize LightRAG
|
||||||
|
# We use TiDB DB as the KV/vector
|
||||||
|
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
|
||||||
|
rag = LightRAG(
|
||||||
|
enable_llm_cache=False,
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
chunk_token_size=512,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=embedding_dimension,
|
||||||
|
max_token_size=512,
|
||||||
|
func=embedding_func,
|
||||||
|
),
|
||||||
|
kv_storage="TiDBKVStorage",
|
||||||
|
vector_storage="TiDBVectorDBStorage",
|
||||||
|
)
|
||||||
|
|
||||||
|
if rag.llm_response_cache:
|
||||||
|
rag.llm_response_cache.db = tidb
|
||||||
|
rag.full_docs.db = tidb
|
||||||
|
rag.text_chunks.db = tidb
|
||||||
|
rag.entities_vdb.db = tidb
|
||||||
|
rag.relationships_vdb.db = tidb
|
||||||
|
rag.chunks_vdb.db = tidb
|
||||||
|
|
||||||
|
# Extract and Insert into LightRAG storage
|
||||||
|
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
||||||
|
await rag.ainsert(f.read())
|
||||||
|
|
||||||
|
# Perform search in different modes
|
||||||
|
modes = ["naive", "local", "global", "hybrid"]
|
||||||
|
for mode in modes:
|
||||||
|
print("=" * 20, mode, "=" * 20)
|
||||||
|
print(
|
||||||
|
await rag.aquery(
|
||||||
|
"What are the top themes in this story?",
|
||||||
|
param=QueryParam(mode=mode),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("-" * 100, "\n")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
446
lightrag/kg/tidb_impl.py
Normal file
446
lightrag/kg/tidb_impl.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lightrag.base import BaseVectorStorage, BaseKVStorage
|
||||||
|
from lightrag.utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
class TiDB(object):
|
||||||
|
def __init__(self, config, **kwargs):
|
||||||
|
self.host = config.get("host", None)
|
||||||
|
self.port = config.get("port", None)
|
||||||
|
self.user = config.get("user", None)
|
||||||
|
self.password = config.get("password", None)
|
||||||
|
self.database = config.get("database", None)
|
||||||
|
self.workspace = config.get("workspace", None)
|
||||||
|
connection_string = (f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||||
|
f"?ssl_verify_cert=true&ssl_verify_identity=true")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.engine = create_engine(connection_string)
|
||||||
|
logger.info(f"Connected to TiDB database at {self.database}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to TiDB database at {self.database}")
|
||||||
|
logger.error(f"TiDB database error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def check_tables(self):
|
||||||
|
for k, v in TABLES.items():
|
||||||
|
try:
|
||||||
|
await self.query(f"SELECT 1 FROM {k}".format(k=k))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to check table {k} in TiDB database")
|
||||||
|
logger.error(f"TiDB database error: {e}")
|
||||||
|
try:
|
||||||
|
# print(v["ddl"])
|
||||||
|
await self.execute(v["ddl"])
|
||||||
|
logger.info(f"Created table {k} in TiDB database")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create table {k} in TiDB database")
|
||||||
|
logger.error(f"TiDB database error: {e}")
|
||||||
|
|
||||||
|
async def query(
|
||||||
|
self, sql: str, params: dict = None, multirows: bool = False
|
||||||
|
) -> Union[dict, None]:
|
||||||
|
if params is None:
|
||||||
|
params = { "workspace": self.workspace }
|
||||||
|
else:
|
||||||
|
params.update({"workspace": self.workspace})
|
||||||
|
with self.engine.connect() as conn, conn.begin():
|
||||||
|
try:
|
||||||
|
result = conn.execute(text(sql), params)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tidb database error: {e}")
|
||||||
|
print(sql)
|
||||||
|
print(params)
|
||||||
|
raise
|
||||||
|
if multirows:
|
||||||
|
rows = result.all()
|
||||||
|
if rows:
|
||||||
|
data = [dict(zip(result.keys(), row)) for row in rows]
|
||||||
|
else:
|
||||||
|
data = []
|
||||||
|
else:
|
||||||
|
row = result.first()
|
||||||
|
if row:
|
||||||
|
data = dict(zip(result.keys(), row))
|
||||||
|
else:
|
||||||
|
data = None
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def execute(self, sql: str, data: list | dict = None):
|
||||||
|
# logger.info("go into TiDBDB execute method")
|
||||||
|
try:
|
||||||
|
with self.engine.connect() as conn, conn.begin():
|
||||||
|
if data is None:
|
||||||
|
conn.execute(text(sql))
|
||||||
|
else:
|
||||||
|
conn.execute(text(sql), parameters=data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TiDB database error: {e}")
|
||||||
|
print(sql)
|
||||||
|
print(data)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TiDBKVStorage(BaseKVStorage):
|
||||||
|
# should pass db object to self.db
|
||||||
|
def __post_init__(self):
|
||||||
|
self._data = {}
|
||||||
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
|
################ QUERY METHODS ################
|
||||||
|
|
||||||
|
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||||
|
"""根据 id 获取 doc_full 数据."""
|
||||||
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
|
params = {"id": id}
|
||||||
|
# print("get_by_id:"+SQL)
|
||||||
|
res = await self.db.query(SQL, params)
|
||||||
|
if res:
|
||||||
|
data = res # {"data":res}
|
||||||
|
# print (data)
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Query by id
|
||||||
|
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||||
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
|
)
|
||||||
|
# print("get_by_ids:"+SQL)
|
||||||
|
res = await self.db.query(SQL, multirows=True)
|
||||||
|
if res:
|
||||||
|
data = res # [{"data":i} for i in res]
|
||||||
|
# print(data)
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
|
"""过滤掉重复内容"""
|
||||||
|
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||||
|
table_name=N_T[self.namespace],
|
||||||
|
id_field= N_ID[self.namespace],
|
||||||
|
ids=",".join([f"'{id}'" for id in keys])
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self.db.query(SQL)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tidb database error: {e}")
|
||||||
|
print(SQL)
|
||||||
|
res = await self.db.query(SQL, multirows=True)
|
||||||
|
if res:
|
||||||
|
exist_keys = [key["id"] for key in res]
|
||||||
|
data = set([s for s in keys if s not in exist_keys])
|
||||||
|
else:
|
||||||
|
exist_keys = []
|
||||||
|
data = set([s for s in keys if s not in exist_keys])
|
||||||
|
return data
|
||||||
|
|
||||||
|
################ INSERT full_doc AND chunks ################
|
||||||
|
async def upsert(self, data: dict[str, dict]):
|
||||||
|
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||||
|
self._data.update(left_data)
|
||||||
|
if self.namespace == "text_chunks":
|
||||||
|
list_data = [
|
||||||
|
{
|
||||||
|
"__id__": k,
|
||||||
|
**{k1: v1 for k1, v1 in v.items()},
|
||||||
|
}
|
||||||
|
for k, v in data.items()
|
||||||
|
]
|
||||||
|
contents = [v["content"] for v in data.values()]
|
||||||
|
batches = [
|
||||||
|
contents[i: i + self._max_batch_size]
|
||||||
|
for i in range(0, len(contents), self._max_batch_size)
|
||||||
|
]
|
||||||
|
embeddings_list = await asyncio.gather(
|
||||||
|
*[self.embedding_func(batch) for batch in batches]
|
||||||
|
)
|
||||||
|
embeddings = np.concatenate(embeddings_list)
|
||||||
|
for i, d in enumerate(list_data):
|
||||||
|
d["__vector__"] = embeddings[i]
|
||||||
|
|
||||||
|
merge_sql = SQL_TEMPLATES["upsert_chunk"]
|
||||||
|
data = []
|
||||||
|
for item in list_data:
|
||||||
|
data.append({
|
||||||
|
"id": item["__id__"],
|
||||||
|
"content": item["content"],
|
||||||
|
"tokens": item["tokens"],
|
||||||
|
"chunk_order_index": item["chunk_order_index"],
|
||||||
|
"full_doc_id": item["full_doc_id"],
|
||||||
|
"content_vector": f"{item["__vector__"].tolist()}",
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
})
|
||||||
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
|
if self.namespace == "full_docs":
|
||||||
|
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
|
||||||
|
data = []
|
||||||
|
for k, v in self._data.items():
|
||||||
|
data.append({
|
||||||
|
"id": k,
|
||||||
|
"content": v["content"],
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
})
|
||||||
|
await self.db.execute(merge_sql, data)
|
||||||
|
return left_data
|
||||||
|
|
||||||
|
async def index_done_callback(self):
|
||||||
|
if self.namespace in ["full_docs", "text_chunks"]:
|
||||||
|
logger.info("full doc and chunk data had been saved into TiDB db!")
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
|
cosine_better_than_threshold: float = 0.2
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._client_file_name = os.path.join(
|
||||||
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
|
)
|
||||||
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
self.cosine_better_than_threshold = self.global_config.get(
|
||||||
|
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||||
|
""" search from tidb vector"""
|
||||||
|
|
||||||
|
embeddings = await self.embedding_func([query])
|
||||||
|
embedding = embeddings[0]
|
||||||
|
|
||||||
|
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"embedding_string": embedding_string,
|
||||||
|
"top_k": top_k,
|
||||||
|
"better_than_threshold": self.cosine_better_than_threshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
results = await self.db.query(SQL_TEMPLATES[self.namespace], params=params, multirows=True)
|
||||||
|
print("vector search result:",results)
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
return results
|
||||||
|
|
||||||
|
###### INSERT entities And relationships ######
|
||||||
|
async def upsert(self, data: dict[str, dict]):
|
||||||
|
# ignore, upsert in TiDBKVStorage already
|
||||||
|
if not len(data):
|
||||||
|
logger.warning("You insert an empty data to vector DB")
|
||||||
|
return []
|
||||||
|
if self.namespace == "chunks":
|
||||||
|
return []
|
||||||
|
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||||
|
|
||||||
|
list_data = [
|
||||||
|
{
|
||||||
|
"id": k,
|
||||||
|
**{k1: v1 for k1, v1 in v.items()},
|
||||||
|
}
|
||||||
|
for k, v in data.items()
|
||||||
|
]
|
||||||
|
contents = [v["content"] for v in data.values()]
|
||||||
|
batches = [
|
||||||
|
contents[i: i + self._max_batch_size]
|
||||||
|
for i in range(0, len(contents), self._max_batch_size)
|
||||||
|
]
|
||||||
|
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||||
|
embeddings_list = []
|
||||||
|
for f in tqdm(
|
||||||
|
asyncio.as_completed(embedding_tasks),
|
||||||
|
total=len(embedding_tasks),
|
||||||
|
desc="Generating embeddings",
|
||||||
|
unit="batch",
|
||||||
|
):
|
||||||
|
embeddings = await f
|
||||||
|
embeddings_list.append(embeddings)
|
||||||
|
embeddings = np.concatenate(embeddings_list)
|
||||||
|
for i, d in enumerate(list_data):
|
||||||
|
d["content_vector"] = embeddings[i]
|
||||||
|
|
||||||
|
if self.namespace == "entities":
|
||||||
|
data = []
|
||||||
|
for item in list_data:
|
||||||
|
merge_sql = SQL_TEMPLATES["upsert_entity"]
|
||||||
|
data.append({
|
||||||
|
"id": item["id"],
|
||||||
|
"name": item["entity_name"],
|
||||||
|
"content": item["content"],
|
||||||
|
"content_vector": f"{item["content_vector"].tolist()}",
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
})
|
||||||
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
|
elif self.namespace == "relationships":
|
||||||
|
data = []
|
||||||
|
for item in list_data:
|
||||||
|
merge_sql = SQL_TEMPLATES["upsert_relationship"]
|
||||||
|
data.append({
|
||||||
|
"id": item["id"],
|
||||||
|
"source_name": item["src_id"],
|
||||||
|
"target_name": item["tgt_id"],
|
||||||
|
"content": item["content"],
|
||||||
|
"content_vector": f"{item["content_vector"].tolist()}",
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
})
|
||||||
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
|
|
||||||
|
N_T = {
|
||||||
|
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||||
|
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
|
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
|
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||||
|
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
||||||
|
}
|
||||||
|
N_ID = {
|
||||||
|
"full_docs": "doc_id",
|
||||||
|
"text_chunks": "chunk_id",
|
||||||
|
"chunks": "chunk_id",
|
||||||
|
"entities": "entity_id",
|
||||||
|
"relationships": "relation_id",
|
||||||
|
}
|
||||||
|
|
||||||
|
TABLES = {
|
||||||
|
"LIGHTRAG_DOC_FULL": {
|
||||||
|
"ddl": """
|
||||||
|
CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||||
|
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||||
|
`doc_id` VARCHAR(256) NOT NULL,
|
||||||
|
`workspace` varchar(1024),
|
||||||
|
`content` LONGTEXT,
|
||||||
|
`meta` JSON,
|
||||||
|
`createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updatetime` TIMESTAMP DEFAULT NULL,
|
||||||
|
UNIQUE KEY (`doc_id`)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_DOC_CHUNKS": {
|
||||||
|
"ddl": """
|
||||||
|
CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||||
|
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||||
|
`chunk_id` VARCHAR(256) NOT NULL,
|
||||||
|
`full_doc_id` VARCHAR(256) NOT NULL,
|
||||||
|
`workspace` varchar(1024),
|
||||||
|
`chunk_order_index` INT,
|
||||||
|
`tokens` INT,
|
||||||
|
`content` LONGTEXT,
|
||||||
|
`content_vector` VECTOR,
|
||||||
|
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updatetime` DATETIME DEFAULT NULL,
|
||||||
|
UNIQUE KEY (`chunk_id`)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_GRAPH_NODES": {
|
||||||
|
"ddl":
|
||||||
|
"""
|
||||||
|
CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
||||||
|
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||||
|
`entity_id` VARCHAR(256) NOT NULL,
|
||||||
|
`workspace` varchar(1024),
|
||||||
|
`name` VARCHAR(2048),
|
||||||
|
`content` LONGTEXT,
|
||||||
|
`content_vector` VECTOR,
|
||||||
|
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updatetime` DATETIME DEFAULT NULL,
|
||||||
|
UNIQUE KEY (`entity_id`)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_GRAPH_EDGES": {
|
||||||
|
"ddl":
|
||||||
|
"""
|
||||||
|
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
||||||
|
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||||
|
`relation_id` VARCHAR(256) NOT NULL,
|
||||||
|
`workspace` varchar(1024),
|
||||||
|
`source_name` VARCHAR(2048),
|
||||||
|
`target_name` VARCHAR(2048),
|
||||||
|
`content` LONGTEXT,
|
||||||
|
`content_vector` VECTOR,
|
||||||
|
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updatetime` DATETIME DEFAULT NULL,
|
||||||
|
UNIQUE KEY (`relation_id`)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_LLM_CACHE": {
|
||||||
|
"ddl": """
|
||||||
|
CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||||
|
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||||
|
send TEXT,
|
||||||
|
return TEXT,
|
||||||
|
model VARCHAR(1024),
|
||||||
|
createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updatetime DATETIME DEFAULT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SQL_TEMPLATES = {
|
||||||
|
# SQL for KVStorage
|
||||||
|
"get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
|
||||||
|
"get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
|
||||||
|
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
|
||||||
|
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
|
||||||
|
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
|
||||||
|
|
||||||
|
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
|
||||||
|
"upsert_doc_full": """
|
||||||
|
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
|
||||||
|
VALUES (:id, :content, :workspace)
|
||||||
|
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||||
|
""",
|
||||||
|
"upsert_chunk": """
|
||||||
|
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
|
||||||
|
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
|
||||||
|
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||||
|
""",
|
||||||
|
|
||||||
|
# SQL for VectorStorage
|
||||||
|
"entities": """SELECT n.name as entity_name FROM
|
||||||
|
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
|
||||||
|
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
|
||||||
|
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""",
|
||||||
|
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
|
||||||
|
(SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||||
|
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
|
||||||
|
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""",
|
||||||
|
"chunks": """SELECT c.id FROM
|
||||||
|
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
|
||||||
|
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
|
||||||
|
|
||||||
|
"upsert_entity": """
|
||||||
|
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
|
||||||
|
VALUES(:id, :name, :content, :content_vector, :workspace)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
|
||||||
|
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||||
|
""",
|
||||||
|
"upsert_relationship": """
|
||||||
|
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
|
||||||
|
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
|
||||||
|
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||||
|
"""
|
||||||
|
}
|
@@ -77,7 +77,8 @@ OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBS
|
|||||||
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
|
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
|
||||||
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
|
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
|
||||||
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
|
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
|
||||||
|
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
|
||||||
|
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
"""
|
"""
|
||||||
@@ -260,11 +261,13 @@ class LightRAG:
|
|||||||
"JsonKVStorage": JsonKVStorage,
|
"JsonKVStorage": JsonKVStorage,
|
||||||
"OracleKVStorage": OracleKVStorage,
|
"OracleKVStorage": OracleKVStorage,
|
||||||
"MongoKVStorage": MongoKVStorage,
|
"MongoKVStorage": MongoKVStorage,
|
||||||
|
"TiDBKVStorage": TiDBKVStorage,
|
||||||
# vector storage
|
# vector storage
|
||||||
"NanoVectorDBStorage": NanoVectorDBStorage,
|
"NanoVectorDBStorage": NanoVectorDBStorage,
|
||||||
"OracleVectorDBStorage": OracleVectorDBStorage,
|
"OracleVectorDBStorage": OracleVectorDBStorage,
|
||||||
"MilvusVectorDBStorge": MilvusVectorDBStorge,
|
"MilvusVectorDBStorge": MilvusVectorDBStorge,
|
||||||
"ChromaVectorDBStorage": ChromaVectorDBStorage,
|
"ChromaVectorDBStorage": ChromaVectorDBStorage,
|
||||||
|
"TiDBVectorDBStorage": TiDBVectorDBStorage,
|
||||||
# graph storage
|
# graph storage
|
||||||
"NetworkXStorage": NetworkXStorage,
|
"NetworkXStorage": NetworkXStorage,
|
||||||
"Neo4JStorage": Neo4JStorage,
|
"Neo4JStorage": Neo4JStorage,
|
||||||
|
@@ -16,6 +16,9 @@ pymongo
|
|||||||
pyvis
|
pyvis
|
||||||
tenacity
|
tenacity
|
||||||
# lmdeploy[all]
|
# lmdeploy[all]
|
||||||
|
sqlalchemy
|
||||||
|
pymysql
|
||||||
|
|
||||||
|
|
||||||
# LLM packages
|
# LLM packages
|
||||||
tiktoken
|
tiktoken
|
||||||
|
Reference in New Issue
Block a user