fix pre commit

Author: jin
Date:   2024-11-12 13:32:40 +08:00
Parent: 90790747ee
Commit: 33caba3e12

7 changed files with 345 additions and 310 deletions

View File

@@ -1,10 +1,10 @@
 from fastapi import FastAPI, HTTPException, File, UploadFile
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
 from typing import Optional
-import sys, os
+import sys
+import os
 from pathlib import Path
 import asyncio
@@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
@@ -24,8 +23,6 @@ script_directory = Path(__file__).resolve().parent.parent
 sys.path.append(os.path.abspath(script_directory))
-
-
 # Apply nest_asyncio to solve event loop issues
 nest_asyncio.apply()
@@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -80,8 +78,8 @@ async def get_embedding_dim():
     embedding_dim = embedding.shape[1]
     return embedding_dim

+
 async def init():
     # Detect embedding dimension
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
@@ -91,36 +89,36 @@ async def init():
     # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(config={
-        "user":"",
-        "password":"",
-        "dsn":"",
-        "config_dir":"",
-        "wallet_location":"",
-        "wallet_password":"",
-        "workspace":""
-        } # specify which docs you want to store and query
+    oracle_db = OracleDB(
+        config={
+            "user": "",
+            "password": "",
+            "dsn": "",
+            "config_dir": "",
+            "wallet_location": "",
+            "wallet_password": "",
+            "workspace": "",
+        }  # specify which docs you want to store and query
     )

     # Check if Oracle DB tables exist, if not, tables will be created
     await oracle_db.check_tables()

     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,
         chunk_token_size=512,
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=embedding_dimension,
             max_token_size=512,
             func=embedding_func,
         ),
-        graph_storage = "OracleGraphStorage",
+        graph_storage="OracleGraphStorage",
         kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage"
+        vector_storage="OracleVectorDBStorage",
     )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
     rag.graph_storage_cls.db = oracle_db
@@ -129,6 +127,7 @@ async def init():
     return rag

+
 # Data models
@@ -152,6 +151,7 @@ class Response(BaseModel):
 rag = None  # defined as a global object

+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
@@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
     yield

-app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
+app = FastAPI(
+    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
+)

 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
     try:
         # loop = asyncio.get_event_loop()
         result = await rag.aquery(
             request.query,
             param=QueryParam(
                 mode=request.mode, only_need_context=request.only_need_context
             ),
         )
         return Response(status="success", data=result)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
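
The reformatted endpoint above takes a JSON body whose fields mirror QueryRequest (query, mode, only_need_context) and returns the Response model (status, data). A minimal client sketch — the base URL is an assumption, adjust it to wherever the app is actually served:

    # Hypothetical client for the /query endpoint above; only the request
    # and response fields come from this diff, the host and port do not.
    import requests

    resp = requests.post(
        "http://localhost:8000/query",
        json={
            "query": "What are the top themes in this story?",
            "mode": "hybrid",
            "only_need_context": False,
        },
    )
    resp.raise_for_status()
    body = resp.json()
    print(body["status"], body["data"])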

View File

@@ -1,11 +1,11 @@
-import sys, os
+import sys
+import os
 from pathlib import Path
 import asyncio
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB

 print(os.getcwd())
@@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -66,22 +67,21 @@ async def main():
         # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
         # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
         # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-        oracle_db = OracleDB(config={
-            "user":"username",
-            "password":"xxxxxxxxx",
-            "dsn":"xxxxxxx_medium",
-            "config_dir":"dir/path/to/oracle/config",
-            "wallet_location":"dir/path/to/oracle/wallet",
-            "wallet_password":"xxxxxxxxx",
-            "workspace":"company" # specify which docs you want to store and query
+        oracle_db = OracleDB(
+            config={
+                "user": "username",
+                "password": "xxxxxxxxx",
+                "dsn": "xxxxxxx_medium",
+                "config_dir": "dir/path/to/oracle/config",
+                "wallet_location": "dir/path/to/oracle/wallet",
+                "wallet_password": "xxxxxxxxx",
+                "workspace": "company",  # specify which docs you want to store and query
             }
         )

         # Check if Oracle DB tables exist, if not, tables will be created
         await oracle_db.check_tables()

         # Initialize LightRAG
         # We use Oracle DB as the KV/vector/graph storage
         rag = LightRAG(
@@ -93,10 +93,10 @@ async def main():
                 embedding_dim=embedding_dimension,
                 max_token_size=512,
                 func=embedding_func,
             ),
-            graph_storage = "OracleGraphStorage",
+            graph_storage="OracleGraphStorage",
             kv_storage="OracleKVStorage",
-            vector_storage="OracleVectorDBStorage"
+            vector_storage="OracleVectorDBStorage",
         )

         # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
@@ -106,14 +106,19 @@ async def main():
         # Extract and Insert into LightRAG storage
         with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
             await rag.ainsert(f.read())

         # Perform search in different modes
         modes = ["naive", "local", "global", "hybrid"]
         for mode in modes:
-            print("="*20, mode, "="*20)
-            print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode)))
-            print("-"*100, "\n")
+            print("=" * 20, mode, "=" * 20)
+            print(
+                await rag.aquery(
+                    "What are the top themes in this story?",
+                    param=QueryParam(mode=mode),
+                )
+            )
+            print("-" * 100, "\n")

     except Exception as e:
         print(f"An error occurred: {e}")

View File

@@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace):
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
     embedding_func: EmbeddingFunc
+
     async def all_keys(self) -> list[str]:
         raise NotImplementedError
@@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc = None
+
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
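
These hunks only add the blank lines the formatter requires after the dataclass fields, but they outline the contract a storage backend implements: async methods on BaseKVStorage and BaseGraphStorage that raise NotImplementedError by default. A minimal in-memory sketch of the KV side — only all_keys is visible in this hunk; the other method names are taken from the Oracle implementation later in this commit:

    # Hypothetical in-memory KV backend following the BaseKVStorage shape.
    from dataclasses import dataclass, field
    from typing import Union


    @dataclass
    class MemoryKVStorage:
        _data: dict = field(default_factory=dict)

        async def all_keys(self) -> list[str]:
            return list(self._data)

        async def get_by_id(self, id: str) -> Union[dict, None]:
            return self._data.get(id)

        async def filter_keys(self, keys: list[str]) -> set[str]:
            # Return the keys that are not stored yet.
            return {k for k in keys if k not in self._data}

        async def upsert(self, data: dict[str, dict]):
            left_data = {k: v for k, v in data.items() if k not in self._data}
            self._data.update(left_data)
            return left_data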

View File

@@ -1,9 +1,9 @@
 import asyncio
-#import html
-#import os
+
+# import html
+# import os
 from dataclasses import dataclass
-from typing import Any, Union, cast
-import networkx as nx
+from typing import Union
 import numpy as np
 import array
@@ -16,8 +16,9 @@ from ..base import (
 import oracledb

+
 class OracleDB:
-    def __init__(self,config,**kwargs):
+    def __init__(self, config, **kwargs):
         self.host = config.get("host", None)
         self.port = config.get("port", None)
         self.user = config.get("user", None)
@@ -37,16 +38,16 @@ class OracleDB:
             oracledb.defaults.fetch_lobs = False
             self.pool = oracledb.create_pool_async(
-                user = self.user,
-                password = self.password,
-                dsn = self.dsn,
-                config_dir = self.config_dir,
-                wallet_location = self.wallet_location,
-                wallet_password = self.wallet_password,
-                min = 1,
-                max = self.max,
-                increment = self.increment
+                user=self.user,
+                password=self.password,
+                dsn=self.dsn,
+                config_dir=self.config_dir,
+                wallet_location=self.wallet_location,
+                wallet_password=self.wallet_password,
+                min=1,
+                max=self.max,
+                increment=self.increment,
             )
             logger.info(f"Connected to Oracle database at {self.dsn}")
         except Exception as e:
             logger.error(f"Failed to connect to Oracle database at {self.dsn}")
@@ -92,10 +93,12 @@ class OracleDB:
         )

     async def check_tables(self):
-        for k,v in TABLES.items():
+        for k, v in TABLES.items():
             try:
                 if k.lower() == "lightrag_graph":
-                    await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
+                    await self.query(
+                        "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
+                    )
                 else:
                     await self.query("SELECT 1 FROM {k}".format(k=k))
             except Exception as e:
@@ -109,11 +112,10 @@ class OracleDB:
                     logger.error(f"Failed to create table {k} in Oracle database")
                     logger.error(f"Oracle database error: {e}")
-        logger.info(f"Finished check all tables in Oracle database")
+        logger.info("Finished check all tables in Oracle database")

-    async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
@@ -138,7 +140,7 @@ class OracleDB:
                 data = None
             return data

-    async def execute(self,sql: str, data: list = None):
+    async def execute(self, sql: str, data: list = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
@@ -148,9 +150,9 @@ class OracleDB:
                     if data is None:
                         await cursor.execute(sql)
                     else:
-                        #print(data)
-                        #print(sql)
-                        await cursor.execute(sql,data)
+                        # print(data)
+                        # print(sql)
+                        await cursor.execute(sql, data)
                     await connection.commit()
         except Exception as e:
             logger.error(f"Oracle database error: {e}")
@@ -158,9 +160,9 @@ class OracleDB:
             print(data)
             raise

+
 @dataclass
 class OracleKVStorage(BaseKVStorage):
     # should pass db object to self.db
     def __post_init__(self):
         self._data = {}
@@ -170,36 +172,41 @@ class OracleKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
-        #print("get_by_id:"+SQL)
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+            workspace=self.db.workspace, id=id
+        )
+        # print("get_by_id:"+SQL)
         res = await self.db.query(SQL)
         if res:
-            data = res #{"data":res}
-            #print (data)
+            data = res  # {"data":res}
+            # print (data)
             return data
         else:
             return None

     # Query by id
-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
+    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by ids."""
-        SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
-            ids=",".join([f"'{id}'" for id in ids]))
-        #print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL,multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+        )
+        # print("get_by_ids:"+SQL)
+        res = await self.db.query(SQL, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
-            #print(data)
+            # print(data)
             return data
         else:
             return None

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicate content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
-            workspace=self.db.workspace,
-            ids=",".join([f"'{k}'" for k in keys]))
-        res = await self.db.query(SQL,multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            workspace=self.db.workspace,
+            ids=",".join([f"'{k}'" for k in keys]),
+        )
+        res = await self.db.query(SQL, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -209,13 +216,12 @@ class OracleKVStorage(BaseKVStorage):
             data = set([s for s in keys if s not in exist_keys])
         return data

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
-        #print(self._data)
-        #values = []
+        # print(self._data)
+        # values = []
         if self.namespace == "text_chunks":
             list_data = [
                 {
@@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
             ]
             contents = [v["content"] for v in data.values()]
             batches = [
-                contents[i: i + self._max_batch_size]
+                contents[i : i + self._max_batch_size]
                 for i in range(0, len(contents), self._max_batch_size)
             ]
             embeddings_list = await asyncio.gather(
@@ -235,35 +241,38 @@ class OracleKVStorage(BaseKVStorage):
             embeddings = np.concatenate(embeddings_list)
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            #print(list_data)
+            # print(list_data)
             for item in list_data:
-                merge_sql = SQL_TEMPLATES["merge_chunk"].format(
-                    check_id=item["__id__"]
-                )
-                values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
-                    item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
-                #print(merge_sql)
+                merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
+                values = [
+                    item["__id__"],
+                    item["content"],
+                    self.db.workspace,
+                    item["tokens"],
+                    item["chunk_order_index"],
+                    item["full_doc_id"],
+                    item["__vector__"],
+                ]
+                # print(merge_sql)
                 await self.db.execute(merge_sql, values)

         if self.namespace == "full_docs":
             for k, v in self._data.items():
-                #values.clear()
+                # values.clear()
                 merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
                     check_id=k,
                 )
                 values = [k, self._data[k]["content"], self.db.workspace]
-                #print(merge_sql)
+                # print(merge_sql)
                 await self.db.execute(merge_sql, values)
         return left_data

     async def index_done_callback(self):
         if self.namespace in ["full_docs", "text_chunks"]:
             logger.info("full doc and chunk data had been saved into oracle db!")

+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
@@ -278,7 +287,6 @@ class OracleVectorDBStorage(BaseVectorStorage):
     async def index_done_callback(self):
         pass
-
     #################### query method ###############
     async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
         """Query data from the vector database."""
@@ -287,19 +295,19 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # Convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = ', '.join(map(str, embedding.tolist()))
+        embedding_string = ", ".join(map(str, embedding.tolist()))

         SQL = SQL_TEMPLATES[self.namespace].format(
             embedding_string=embedding_string,
             dimension=dimension,
             dtype=dtype,
             workspace=self.db.workspace,
             top_k=top_k,
             better_than_threshold=self.cosine_better_than_threshold,
         )
         # print(SQL)
         results = await self.db.query(SQL, multirows=True)
-        #print("vector search result:",results)
+        # print("vector search result:",results)
         return results
@@ -311,20 +319,19 @@ class OracleGraphStorage(BaseGraphStorage):
         """Load the graph from the graphml file."""
         self._max_batch_size = self.global_config["embedding_batch_num"]
-
     #################### insert method ################
     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
         """Insert or update a node."""
-        #print("go into upsert node method")
+        # print("go into upsert node method")
         entity_name = node_id
         entity_type = node_data["entity_type"]
         description = node_data["description"]
         source_id = node_data["source_id"]
-        content = entity_name+description
+        content = entity_name + description
         contents = [content]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
@@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage):
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
         merge_sql = SQL_TEMPLATES["merge_node"].format(
-            workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
+            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
         )
-        #print(merge_sql)
-        await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
-        #self._graph.add_node(node_id, **node_data)
+        # print(merge_sql)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                entity_name,
+                entity_type,
+                description,
+                source_id,
+                content,
+                content_vector,
+            ],
+        )
+        # self._graph.add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         """Insert or update an edge."""
-        #print("go into upsert edge method")
+        # print("go into upsert edge method")
         source_name = source_node_id
         target_name = target_node_id
         weight = edge_data["weight"]
         keywords = edge_data["keywords"]
         description = edge_data["description"]
         source_chunk_id = edge_data["source_id"]
-        content = keywords+source_name+target_name+description
+        content = keywords + source_name + target_name + description
         contents = [content]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
@@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage):
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
         merge_sql = SQL_TEMPLATES["merge_edge"].format(
-            workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
+            workspace=self.db.workspace,
+            source_name=source_name,
+            target_name=target_name,
+            source_chunk_id=source_chunk_id,
         )
-        #print(merge_sql)
-        await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
-        #self._graph.add_edge(source_node_id, target_node_id, **edge_data)
+        # print(merge_sql)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                source_name,
+                target_name,
+                weight,
+                keywords,
+                description,
+                source_chunk_id,
+                content,
+                content_vector,
+            ],
+        )
+        # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         """Generate vectors for nodes."""
@@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

     async def index_done_callback(self):
         """Write the graphml graph file."""
-        logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
+        logger.info(
+            "Node and edge data had been saved into oracle db already, so nothing to do here!"
+        )

     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by node id."""
-        SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
+        SQL = SQL_TEMPLATES["has_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
-        #print(self.db.workspace, node_id)
+        # print(self.db.workspace, node_id)
         res = await self.db.query(SQL)
         if res:
-            #print("Node exist!",res)
+            # print("Node exist!",res)
             return True
         else:
-            #print("Node not exist!")
+            # print("Node not exist!")
             return False

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by its source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id)
+        SQL = SQL_TEMPLATES["has_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Edge exist!",res)
+            # print("Edge exist!",res)
             return True
         else:
-            #print("Edge not exist!")
+            # print("Edge not exist!")
             return False

     async def node_degree(self, node_id: str) -> int:
         """Get a node's degree by node id."""
-        SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
+        SQL = SQL_TEMPLATES["node_degree"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Node degree",res["degree"])
+            # print("Node degree",res["degree"])
             return res["degree"]
         else:
-            #print("Edge not exist!")
+            # print("Edge not exist!")
             return 0

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get an edge's degree from its source and target node ids."""
         degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
-        #print("Edge degree",degree)
+        # print("Edge degree",degree)
         return degree

     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by node id."""
-        SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
+        SQL = SQL_TEMPLATES["get_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(self.db.workspace, node_id)
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Get node!",self.db.workspace, node_id,res)
+            # print("Get node!",self.db.workspace, node_id,res)
             return res
         else:
-            #print("Can't get node!",self.db.workspace, node_id)
+            # print("Can't get node!",self.db.workspace, node_id)
             return None

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by its source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id)
+        SQL = SQL_TEMPLATES["get_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         res = await self.db.query(SQL)
         if res:
-            #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
+            # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
         else:
-            #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
+            # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
             return None

     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
-                source_node_id=source_node_id)
+            SQL = SQL_TEMPLATES["get_node_edges"].format(
+                workspace=self.db.workspace, source_node_id=source_node_id
+            )
             res = await self.db.query(sql=SQL, multirows=True)
             if res:
-                data = [(i["source_name"],i["target_name"]) for i in res]
-                #print("Get node edge!",self.db.workspace, source_node_id,data)
+                data = [(i["source_name"], i["target_name"]) for i in res]
+                # print("Get node edge!",self.db.workspace, source_node_id,data)
                 return data
             else:
-                #print("Node Edge not exist!",self.db.workspace, source_node_id)
+                # print("Node Edge not exist!",self.db.workspace, source_node_id)
                 return []
@@ -487,12 +531,12 @@ N_T = {
     "text_chunks": "LIGHTRAG_DOC_CHUNKS",
     "chunks": "LIGHTRAG_DOC_CHUNKS",
     "entities": "LIGHTRAG_GRAPH_NODES",
-    "relationships": "LIGHTRAG_GRAPH_EDGES"
+    "relationships": "LIGHTRAG_GRAPH_EDGES",
 }

 TABLES = {
-    "LIGHTRAG_DOC_FULL":
-        {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
+    "LIGHTRAG_DOC_FULL": {
+        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                     id varchar(256)PRIMARY KEY,
                     workspace varchar(1024),
                     doc_name varchar(1024),
@@ -500,10 +544,10 @@ TABLES = {
                     meta JSON,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-                    )"""},
-    "LIGHTRAG_DOC_CHUNKS":
-        {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
+                    )"""
+    },
+    "LIGHTRAG_DOC_CHUNKS": {
+        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                     id varchar(256) PRIMARY KEY,
                     workspace varchar(1024),
                     full_doc_id varchar(256),
@@ -513,10 +557,10 @@ TABLES = {
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-                    )"""},
-    "LIGHTRAG_GRAPH_NODES":
-        {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
+                    )"""
+    },
+    "LIGHTRAG_GRAPH_NODES": {
+        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
                     id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                     workspace varchar(1024),
                     name varchar(2048),
@@ -527,9 +571,10 @@ TABLES = {
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-                    )"""},
-    "LIGHTRAG_GRAPH_EDGES":
-        {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
+                    )"""
+    },
+    "LIGHTRAG_GRAPH_EDGES": {
+        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
                     id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                     workspace varchar(1024),
                     source_name varchar(2048),
@@ -542,19 +587,20 @@ TABLES = {
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-                    )"""},
-    "LIGHTRAG_LLM_CACHE":
-        {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
+                    )"""
+    },
+    "LIGHTRAG_LLM_CACHE": {
+        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                     id varchar(256) PRIMARY KEY,
                     send clob,
                     return clob,
                     model varchar(1024),
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-                    )"""},
-    "LIGHTRAG_GRAPH":
-        {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
+                    )"""
+    },
+    "LIGHTRAG_GRAPH": {
+        "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
         VERTEX TABLES (
             lightrag_graph_nodes KEY (id)
                 LABEL entity
@@ -566,92 +612,66 @@ TABLES = {
             DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
             LABEL has_relation
             PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
-        ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""},
-    }
+        ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
+    },
+}

 SQL_TEMPLATES = {
     # SQL for KVStorage
-    "get_by_id_full_docs":
-        "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
-    "get_by_id_text_chunks":
-        "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
-    "get_by_ids_full_docs":
-        "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
-    "get_by_ids_text_chunks":
-        "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
-    "filter_keys":
-        "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
+    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
+    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
+    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
+    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
+    "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
     "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
     USING DUAL
     ON (a.id = '{check_id}')
     WHEN NOT MATCHED THEN
        INSERT(id,content,workspace) values(:1,:2,:3)
     """,
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
     USING DUAL
     ON (a.id = '{check_id}')
     WHEN NOT MATCHED THEN
         INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
         values (:1,:2,:3,:4,:5,:6,:7) """,
     # SQL for VectorStorage
-    "entities":
-        """SELECT name as entity_name FROM
+    "entities": """SELECT name as entity_name FROM
     (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
     FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
     WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
-    "relationships":
-        """SELECT source_name as src_id, target_name as tgt_id FROM
+    "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
    (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
    FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
    WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
-    "chunks":
-        """SELECT id FROM
+    "chunks": """SELECT id FROM
     (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
     FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
     WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
     # SQL for GraphStorage
-    "has_node":
-        """SELECT * FROM GRAPH_TABLE (lightrag_graph
+    "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)
     WHERE a.workspace='{workspace}' AND a.name='{node_id}'
     COLUMNS (a.name))""",
-    "has_edge":
-        """SELECT * FROM GRAPH_TABLE (lightrag_graph
+    "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
     MATCH (a) -[e]-> (b)
     WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
     AND a.name='{source_node_id}' AND b.name='{target_node_id}'
     COLUMNS (e.source_name,e.target_name) )""",
-    "node_degree":
-        """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
+    "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
     WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
     AND a.name='{node_id}' or b.name = '{node_id}'
     COLUMNS (a.name))""",
-    "get_node":
-        """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
+    "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)
     WHERE a.workspace='{workspace}' AND a.name='{node_id}'
     COLUMNS (a.name)
     ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
     WHERE t2.workspace='{workspace}'""",
-    "get_edge":
-        """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
+    "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
     NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
@@ -659,15 +679,12 @@ SQL_TEMPLATES = {
     AND a.name='{source_node_id}' and b.name = '{target_node_id}'
     COLUMNS (e.id,a.name as source_id)
     ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
-    "get_node_edges":
-        """SELECT source_name,target_name
+    "get_node_edges": """SELECT source_name,target_name
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
     WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
     AND a.name='{source_node_id}'
     COLUMNS (a.name as source_name,b.name as target_name))""",
     "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
     USING DUAL
     ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
@@ -679,5 +696,5 @@ SQL_TEMPLATES = {
     ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
     WHEN NOT MATCHED THEN
         INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-        values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
+        values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
 }
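
The merge templates above interpolate the match key with str.format and bind the inserted values positionally (:1 … :9) through OracleDB.execute. A standalone sketch of the same upsert pattern against the LIGHTRAG_DOC_FULL table, using named binds for every value instead of formatting the key into the SQL — credentials and DSN are placeholders:

    # Sketch of the MERGE-then-bind pattern used by OracleKVStorage.upsert.
    # The SQL mirrors the merge_doc_full template, but binds the id rather
    # than interpolating it, which sidesteps quoting issues in ids.
    import asyncio

    import oracledb

    MERGE_DOC_FULL = """MERGE INTO LIGHTRAG_DOC_FULL a
        USING DUAL
        ON (a.id = :id)
        WHEN NOT MATCHED THEN
            INSERT (id, content, workspace) VALUES (:id, :content, :workspace)"""


    async def upsert_doc(doc_id: str, content: str, workspace: str):
        pool = oracledb.create_pool_async(
            user="user", password="password", dsn="dsn", min=1, max=4, increment=1
        )
        async with pool.acquire() as connection:
            with connection.cursor() as cursor:
                await cursor.execute(
                    MERGE_DOC_FULL,
                    {"id": doc_id, "content": content, "workspace": workspace},
                )
            await connection.commit()
        await pool.close()


    asyncio.run(upsert_doc("doc-1", "full text ...", "company"))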

View File

@@ -38,15 +38,11 @@ from .storage import (
     JsonKVStorage,
     NanoVectorDBStorage,
     NetworkXStorage,
 )

 from .kg.neo4j_impl import Neo4JStorage
-from .kg.oracle_impl import (
-    OracleKVStorage,
-    OracleGraphStorage,
-    OracleVectorDBStorage
-)
+from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage

 # future KG integrations
@@ -54,6 +50,7 @@ from .kg.oracle_impl import (
 # GraphStorage as ArangoDBStorage
 # )

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         return asyncio.get_event_loop()
@@ -72,7 +69,7 @@ class LightRAG:
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )

-    kv_storage : str = field(default="JsonKVStorage")
+    kv_storage: str = field(default="JsonKVStorage")
     vector_storage: str = field(default="NanoVectorDBStorage")
     graph_storage: str = field(default="NetworkXStorage")
@@ -134,18 +131,25 @@ class LightRAG:
         # @TODO: should move all storage setup here to leverage initial start params attached to self.
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage]
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage]
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage]
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+            self._get_storage_class()[self.kv_storage]
+        )
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
+            self.vector_storage
+        ]
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
+            self.graph_storage
+        ]

         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)

         self.llm_response_cache = (
             self.key_string_value_json_storage_cls(
-                namespace="llm_response_cache", global_config=asdict(self),embedding_func=None
+                namespace="llm_response_cache",
+                global_config=asdict(self),
+                embedding_func=None,
             )
             if self.enable_llm_cache
             else None
@@ -159,13 +163,19 @@ class LightRAG:
         # add embedding func by walter
         ####
         self.full_docs = self.key_string_value_json_storage_cls(
-            namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func
+            namespace="full_docs",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
         )
         self.text_chunks = self.key_string_value_json_storage_cls(
-            namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func
+            namespace="text_chunks",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func
+            namespace="chunk_entity_relation",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
         )
         ####
         # add embedding func by walter over
@@ -200,13 +210,11 @@ class LightRAG:
     def _get_storage_class(self) -> Type[BaseGraphStorage]:
         return {
             # kv storage
-            "JsonKVStorage":JsonKVStorage,
-            "OracleKVStorage":OracleKVStorage,
+            "JsonKVStorage": JsonKVStorage,
+            "OracleKVStorage": OracleKVStorage,
             # vector storage
-            "NanoVectorDBStorage":NanoVectorDBStorage,
-            "OracleVectorDBStorage":OracleVectorDBStorage,
+            "NanoVectorDBStorage": NanoVectorDBStorage,
+            "OracleVectorDBStorage": OracleVectorDBStorage,
             # graph storage
             "NetworkXStorage": NetworkXStorage,
             "Neo4JStorage": Neo4JStorage,

View File

@@ -16,7 +16,7 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
-    locate_json_string_body_from_string
+    locate_json_string_body_from_string,
 )
 from .base import (
     BaseGraphStorage,

View File

@@ -1,22 +1,22 @@
accelerate accelerate
aioboto3
aiohttp aiohttp
# database packages
graspologic
hnswlib
nano-vectordb
neo4j
networkx
ollama
openai
oracledb
pyvis pyvis
tenacity tenacity
xxhash
# lmdeploy[all] # lmdeploy[all]
# LLM packages # LLM packages
tiktoken tiktoken
torch torch
transformers transformers
aioboto3 xxhash
ollama
openai
# database packages
graspologic
hnswlib
networkx
oracledb
nano-vectordb
neo4j