fix pre commit
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import sys, os
|
import sys
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam
|
|||||||
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from lightrag.kg.oracle_impl import OracleDB
|
from lightrag.kg.oracle_impl import OracleDB
|
||||||
|
|
||||||
@@ -24,8 +23,6 @@ script_directory = Path(__file__).resolve().parent.parent
|
|||||||
sys.path.append(os.path.abspath(script_directory))
|
sys.path.append(os.path.abspath(script_directory))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Apply nest_asyncio to solve event loop issues
|
# Apply nest_asyncio to solve event loop issues
|
||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
|
|
||||||
@@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
|
|||||||
if not os.path.exists(WORKING_DIR):
|
if not os.path.exists(WORKING_DIR):
|
||||||
os.mkdir(WORKING_DIR)
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
async def llm_model_func(
|
async def llm_model_func(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -80,8 +78,8 @@ async def get_embedding_dim():
|
|||||||
embedding_dim = embedding.shape[1]
|
embedding_dim = embedding.shape[1]
|
||||||
return embedding_dim
|
return embedding_dim
|
||||||
|
|
||||||
|
|
||||||
async def init():
|
async def init():
|
||||||
|
|
||||||
# Detect embedding dimension
|
# Detect embedding dimension
|
||||||
embedding_dimension = await get_embedding_dim()
|
embedding_dimension = await get_embedding_dim()
|
||||||
print(f"Detected embedding dimension: {embedding_dimension}")
|
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||||
@@ -91,36 +89,36 @@ async def init():
|
|||||||
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
||||||
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
||||||
|
|
||||||
|
oracle_db = OracleDB(
|
||||||
|
config={
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"dsn": "",
|
||||||
|
"config_dir": "",
|
||||||
|
"wallet_location": "",
|
||||||
|
"wallet_password": "",
|
||||||
|
"workspace": "",
|
||||||
|
} # specify which docs you want to store and query
|
||||||
|
)
|
||||||
|
|
||||||
oracle_db = OracleDB(config={
|
|
||||||
"user":"",
|
|
||||||
"password":"",
|
|
||||||
"dsn":"",
|
|
||||||
"config_dir":"",
|
|
||||||
"wallet_location":"",
|
|
||||||
"wallet_password":"",
|
|
||||||
"workspace":""
|
|
||||||
} # specify which docs you want to store and query
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Oracle DB tables exist, if not, tables will be created
|
# Check if Oracle DB tables exist, if not, tables will be created
|
||||||
await oracle_db.check_tables()
|
await oracle_db.check_tables()
|
||||||
# Initialize LightRAG
|
# Initialize LightRAG
|
||||||
# We use Oracle DB as the KV/vector/graph storage
|
# We use Oracle DB as the KV/vector/graph storage
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
enable_llm_cache=False,
|
enable_llm_cache=False,
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
chunk_token_size=512,
|
chunk_token_size=512,
|
||||||
llm_model_func=llm_model_func,
|
llm_model_func=llm_model_func,
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=embedding_dimension,
|
embedding_dim=embedding_dimension,
|
||||||
max_token_size=512,
|
max_token_size=512,
|
||||||
func=embedding_func,
|
func=embedding_func,
|
||||||
),
|
),
|
||||||
graph_storage = "OracleGraphStorage",
|
graph_storage="OracleGraphStorage",
|
||||||
kv_storage="OracleKVStorage",
|
kv_storage="OracleKVStorage",
|
||||||
vector_storage="OracleVectorDBStorage"
|
vector_storage="OracleVectorDBStorage",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
||||||
rag.graph_storage_cls.db = oracle_db
|
rag.graph_storage_cls.db = oracle_db
|
||||||
@@ -129,6 +127,7 @@ async def init():
|
|||||||
|
|
||||||
return rag
|
return rag
|
||||||
|
|
||||||
|
|
||||||
# Data models
|
# Data models
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +151,7 @@ class Response(BaseModel):
|
|||||||
|
|
||||||
rag = None # 定义为全局对象
|
rag = None # 定义为全局对象
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global rag
|
global rag
|
||||||
@@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
|
app = FastAPI(
|
||||||
|
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query", response_model=Response)
|
@app.post("/query", response_model=Response)
|
||||||
async def query_endpoint(request: QueryRequest):
|
async def query_endpoint(request: QueryRequest):
|
||||||
try:
|
try:
|
||||||
# loop = asyncio.get_event_loop()
|
# loop = asyncio.get_event_loop()
|
||||||
result = await rag.aquery(
|
result = await rag.aquery(
|
||||||
request.query,
|
request.query,
|
||||||
param=QueryParam(
|
param=QueryParam(
|
||||||
mode=request.mode, only_need_context=request.only_need_context
|
mode=request.mode, only_need_context=request.only_need_context
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return Response(status="success", data=result)
|
return Response(status="success", data=result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -234,4 +237,4 @@ if __name__ == "__main__":
|
|||||||
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
|
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
|
||||||
|
|
||||||
# 4. Health check:
|
# 4. Health check:
|
||||||
# curl -X GET "http://127.0.0.1:8020/health"
|
# curl -X GET "http://127.0.0.1:8020/health"
|
||||||
|
@@ -1,11 +1,11 @@
|
|||||||
import sys, os
|
import sys
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
|
||||||
from lightrag.kg.oracle_impl import OracleDB
|
from lightrag.kg.oracle_impl import OracleDB
|
||||||
|
|
||||||
print(os.getcwd())
|
print(os.getcwd())
|
||||||
@@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
|||||||
if not os.path.exists(WORKING_DIR):
|
if not os.path.exists(WORKING_DIR):
|
||||||
os.mkdir(WORKING_DIR)
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
async def llm_model_func(
|
async def llm_model_func(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -66,22 +67,21 @@ async def main():
|
|||||||
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
||||||
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
||||||
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
||||||
oracle_db = OracleDB(config={
|
oracle_db = OracleDB(
|
||||||
"user":"username",
|
config={
|
||||||
"password":"xxxxxxxxx",
|
"user": "username",
|
||||||
"dsn":"xxxxxxx_medium",
|
"password": "xxxxxxxxx",
|
||||||
"config_dir":"dir/path/to/oracle/config",
|
"dsn": "xxxxxxx_medium",
|
||||||
"wallet_location":"dir/path/to/oracle/wallet",
|
"config_dir": "dir/path/to/oracle/config",
|
||||||
"wallet_password":"xxxxxxxxx",
|
"wallet_location": "dir/path/to/oracle/wallet",
|
||||||
"workspace":"company" # specify which docs you want to store and query
|
"wallet_password": "xxxxxxxxx",
|
||||||
|
"workspace": "company", # specify which docs you want to store and query
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Check if Oracle DB tables exist, if not, tables will be created
|
# Check if Oracle DB tables exist, if not, tables will be created
|
||||||
await oracle_db.check_tables()
|
await oracle_db.check_tables()
|
||||||
|
|
||||||
|
|
||||||
# Initialize LightRAG
|
# Initialize LightRAG
|
||||||
# We use Oracle DB as the KV/vector/graph storage
|
# We use Oracle DB as the KV/vector/graph storage
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
@@ -93,10 +93,10 @@ async def main():
|
|||||||
embedding_dim=embedding_dimension,
|
embedding_dim=embedding_dimension,
|
||||||
max_token_size=512,
|
max_token_size=512,
|
||||||
func=embedding_func,
|
func=embedding_func,
|
||||||
),
|
),
|
||||||
graph_storage = "OracleGraphStorage",
|
graph_storage="OracleGraphStorage",
|
||||||
kv_storage="OracleKVStorage",
|
kv_storage="OracleKVStorage",
|
||||||
vector_storage="OracleVectorDBStorage"
|
vector_storage="OracleVectorDBStorage",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
||||||
@@ -106,18 +106,23 @@ async def main():
|
|||||||
|
|
||||||
# Extract and Insert into LightRAG storage
|
# Extract and Insert into LightRAG storage
|
||||||
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
||||||
await rag.ainsert(f.read())
|
await rag.ainsert(f.read())
|
||||||
|
|
||||||
# Perform search in different modes
|
# Perform search in different modes
|
||||||
modes = ["naive", "local", "global", "hybrid"]
|
modes = ["naive", "local", "global", "hybrid"]
|
||||||
for mode in modes:
|
for mode in modes:
|
||||||
print("="*20, mode, "="*20)
|
print("=" * 20, mode, "=" * 20)
|
||||||
print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode)))
|
print(
|
||||||
print("-"*100, "\n")
|
await rag.aquery(
|
||||||
|
"What are the top themes in this story?",
|
||||||
|
param=QueryParam(mode=mode),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("-" * 100, "\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred: {e}")
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
@@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BaseKVStorage(Generic[T], StorageNameSpace):
|
class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
async def all_keys(self) -> list[str]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BaseGraphStorage(StorageNameSpace):
|
class BaseGraphStorage(StorageNameSpace):
|
||||||
embedding_func: EmbeddingFunc = None
|
embedding_func: EmbeddingFunc = None
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@@ -1,9 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
#import html
|
|
||||||
#import os
|
# import html
|
||||||
|
# import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union, cast
|
from typing import Union
|
||||||
import networkx as nx
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import array
|
import array
|
||||||
|
|
||||||
@@ -16,8 +16,9 @@ from ..base import (
|
|||||||
|
|
||||||
import oracledb
|
import oracledb
|
||||||
|
|
||||||
|
|
||||||
class OracleDB:
|
class OracleDB:
|
||||||
def __init__(self,config,**kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
self.host = config.get("host", None)
|
self.host = config.get("host", None)
|
||||||
self.port = config.get("port", None)
|
self.port = config.get("port", None)
|
||||||
self.user = config.get("user", None)
|
self.user = config.get("user", None)
|
||||||
@@ -32,21 +33,21 @@ class OracleDB:
|
|||||||
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
||||||
if self.user is None or self.password is None:
|
if self.user is None or self.password is None:
|
||||||
raise ValueError("Missing database user or password in addon_params")
|
raise ValueError("Missing database user or password in addon_params")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
oracledb.defaults.fetch_lobs = False
|
oracledb.defaults.fetch_lobs = False
|
||||||
|
|
||||||
self.pool = oracledb.create_pool_async(
|
self.pool = oracledb.create_pool_async(
|
||||||
user = self.user,
|
user=self.user,
|
||||||
password = self.password,
|
password=self.password,
|
||||||
dsn = self.dsn,
|
dsn=self.dsn,
|
||||||
config_dir = self.config_dir,
|
config_dir=self.config_dir,
|
||||||
wallet_location = self.wallet_location,
|
wallet_location=self.wallet_location,
|
||||||
wallet_password = self.wallet_password,
|
wallet_password=self.wallet_password,
|
||||||
min = 1,
|
min=1,
|
||||||
max = self.max,
|
max=self.max,
|
||||||
increment = self.increment
|
increment=self.increment,
|
||||||
)
|
)
|
||||||
logger.info(f"Connected to Oracle database at {self.dsn}")
|
logger.info(f"Connected to Oracle database at {self.dsn}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
||||||
@@ -90,12 +91,14 @@ class OracleDB:
|
|||||||
arraysize=cursor.arraysize,
|
arraysize=cursor.arraysize,
|
||||||
outconverter=self.numpy_converter_out,
|
outconverter=self.numpy_converter_out,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_tables(self):
|
async def check_tables(self):
|
||||||
for k,v in TABLES.items():
|
for k, v in TABLES.items():
|
||||||
try:
|
try:
|
||||||
if k.lower() == "lightrag_graph":
|
if k.lower() == "lightrag_graph":
|
||||||
await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
|
await self.query(
|
||||||
|
"SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await self.query("SELECT 1 FROM {k}".format(k=k))
|
await self.query("SELECT 1 FROM {k}".format(k=k))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -108,12 +111,11 @@ class OracleDB:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create table {k} in Oracle database")
|
logger.error(f"Failed to create table {k} in Oracle database")
|
||||||
logger.error(f"Oracle database error: {e}")
|
logger.error(f"Oracle database error: {e}")
|
||||||
|
|
||||||
logger.info(f"Finished check all tables in Oracle database")
|
logger.info("Finished check all tables in Oracle database")
|
||||||
|
|
||||||
|
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
|
||||||
async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
|
async with self.pool.acquire() as connection:
|
||||||
async with self.pool.acquire() as connection:
|
|
||||||
connection.inputtypehandler = self.input_type_handler
|
connection.inputtypehandler = self.input_type_handler
|
||||||
connection.outputtypehandler = self.output_type_handler
|
connection.outputtypehandler = self.output_type_handler
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
@@ -136,9 +138,9 @@ class OracleDB:
|
|||||||
data = dict(zip(columns, row))
|
data = dict(zip(columns, row))
|
||||||
else:
|
else:
|
||||||
data = None
|
data = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def execute(self,sql: str, data: list = None):
|
async def execute(self, sql: str, data: list = None):
|
||||||
# logger.info("go into OracleDB execute method")
|
# logger.info("go into OracleDB execute method")
|
||||||
try:
|
try:
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
@@ -148,58 +150,63 @@ class OracleDB:
|
|||||||
if data is None:
|
if data is None:
|
||||||
await cursor.execute(sql)
|
await cursor.execute(sql)
|
||||||
else:
|
else:
|
||||||
#print(data)
|
# print(data)
|
||||||
#print(sql)
|
# print(sql)
|
||||||
await cursor.execute(sql,data)
|
await cursor.execute(sql, data)
|
||||||
await connection.commit()
|
await connection.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Oracle database error: {e}")
|
logger.error(f"Oracle database error: {e}")
|
||||||
print(sql)
|
print(sql)
|
||||||
print(data)
|
print(data)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleKVStorage(BaseKVStorage):
|
class OracleKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
# should pass db object to self.db
|
# should pass db object to self.db
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
################ QUERY METHODS ################
|
################ QUERY METHODS ################
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||||
"""根据 id 获取 doc_full 数据."""
|
"""根据 id 获取 doc_full 数据."""
|
||||||
SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
|
||||||
#print("get_by_id:"+SQL)
|
workspace=self.db.workspace, id=id
|
||||||
res = await self.db.query(SQL)
|
)
|
||||||
|
# print("get_by_id:"+SQL)
|
||||||
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
data = res #{"data":res}
|
data = res # {"data":res}
|
||||||
#print (data)
|
# print (data)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
|
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||||
"""根据 id 获取 doc_chunks 数据"""
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
ids=",".join([f"'{id}'" for id in ids]))
|
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
|
||||||
#print("get_by_ids:"+SQL)
|
)
|
||||||
res = await self.db.query(SQL,multirows=True)
|
# print("get_by_ids:"+SQL)
|
||||||
|
res = await self.db.query(SQL, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = res # [{"data":i} for i in res]
|
data = res # [{"data":i} for i in res]
|
||||||
#print(data)
|
# print(data)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
"""过滤掉重复内容"""
|
"""过滤掉重复内容"""
|
||||||
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||||
workspace=self.db.workspace,
|
table_name=N_T[self.namespace],
|
||||||
ids=",".join([f"'{k}'" for k in keys]))
|
workspace=self.db.workspace,
|
||||||
res = await self.db.query(SQL,multirows=True)
|
ids=",".join([f"'{k}'" for k in keys]),
|
||||||
|
)
|
||||||
|
res = await self.db.query(SQL, multirows=True)
|
||||||
data = None
|
data = None
|
||||||
if res:
|
if res:
|
||||||
exist_keys = [key["id"] for key in res]
|
exist_keys = [key["id"] for key in res]
|
||||||
@@ -208,14 +215,13 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
exist_keys = []
|
exist_keys = []
|
||||||
data = set([s for s in keys if s not in exist_keys])
|
data = set([s for s in keys if s not in exist_keys])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
################ INSERT METHODS ################
|
################ INSERT METHODS ################
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||||
self._data.update(left_data)
|
self._data.update(left_data)
|
||||||
#print(self._data)
|
# print(self._data)
|
||||||
#values = []
|
# values = []
|
||||||
if self.namespace == "text_chunks":
|
if self.namespace == "text_chunks":
|
||||||
list_data = [
|
list_data = [
|
||||||
{
|
{
|
||||||
@@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
]
|
]
|
||||||
contents = [v["content"] for v in data.values()]
|
contents = [v["content"] for v in data.values()]
|
||||||
batches = [
|
batches = [
|
||||||
contents[i: i + self._max_batch_size]
|
contents[i : i + self._max_batch_size]
|
||||||
for i in range(0, len(contents), self._max_batch_size)
|
for i in range(0, len(contents), self._max_batch_size)
|
||||||
]
|
]
|
||||||
embeddings_list = await asyncio.gather(
|
embeddings_list = await asyncio.gather(
|
||||||
@@ -235,42 +241,45 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
for i, d in enumerate(list_data):
|
for i, d in enumerate(list_data):
|
||||||
d["__vector__"] = embeddings[i]
|
d["__vector__"] = embeddings[i]
|
||||||
#print(list_data)
|
# print(list_data)
|
||||||
for item in list_data:
|
for item in list_data:
|
||||||
merge_sql = SQL_TEMPLATES["merge_chunk"].format(
|
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
|
||||||
check_id=item["__id__"]
|
|
||||||
)
|
|
||||||
|
|
||||||
values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
|
values = [
|
||||||
item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
|
item["__id__"],
|
||||||
#print(merge_sql)
|
item["content"],
|
||||||
|
self.db.workspace,
|
||||||
|
item["tokens"],
|
||||||
|
item["chunk_order_index"],
|
||||||
|
item["full_doc_id"],
|
||||||
|
item["__vector__"],
|
||||||
|
]
|
||||||
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, values)
|
await self.db.execute(merge_sql, values)
|
||||||
|
|
||||||
if self.namespace == "full_docs":
|
if self.namespace == "full_docs":
|
||||||
for k, v in self._data.items():
|
for k, v in self._data.items():
|
||||||
#values.clear()
|
# values.clear()
|
||||||
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
||||||
check_id=k,
|
check_id=k,
|
||||||
)
|
)
|
||||||
values = [k, self._data[k]["content"], self.db.workspace]
|
values = [k, self._data[k]["content"], self.db.workspace]
|
||||||
#print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, values)
|
await self.db.execute(merge_sql, values)
|
||||||
return left_data
|
return left_data
|
||||||
|
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
if self.namespace in ["full_docs", "text_chunks"]:
|
if self.namespace in ["full_docs", "text_chunks"]:
|
||||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = 0.2
|
cosine_better_than_threshold: float = 0.2
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
"""向向量数据库中插入数据"""
|
"""向向量数据库中插入数据"""
|
||||||
pass
|
pass
|
||||||
@@ -278,53 +287,51 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||||
"""从向量数据库中查询数据"""
|
"""从向量数据库中查询数据"""
|
||||||
embeddings = await self.embedding_func([query])
|
embeddings = await self.embedding_func([query])
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
# 转换精度
|
# 转换精度
|
||||||
dtype = str(embedding.dtype).upper()
|
dtype = str(embedding.dtype).upper()
|
||||||
dimension = embedding.shape[0]
|
dimension = embedding.shape[0]
|
||||||
embedding_string = ', '.join(map(str, embedding.tolist()))
|
embedding_string = ", ".join(map(str, embedding.tolist()))
|
||||||
|
|
||||||
SQL = SQL_TEMPLATES[self.namespace].format(
|
SQL = SQL_TEMPLATES[self.namespace].format(
|
||||||
embedding_string=embedding_string,
|
embedding_string=embedding_string,
|
||||||
dimension=dimension,
|
dimension=dimension,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
workspace=self.db.workspace,
|
workspace=self.db.workspace,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
better_than_threshold=self.cosine_better_than_threshold,
|
better_than_threshold=self.cosine_better_than_threshold,
|
||||||
)
|
)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
results = await self.db.query(SQL, multirows=True)
|
results = await self.db.query(SQL, multirows=True)
|
||||||
#print("vector search result:",results)
|
# print("vector search result:",results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleGraphStorage(BaseGraphStorage):
|
class OracleGraphStorage(BaseGraphStorage):
|
||||||
"""基于Oracle的图存储模块"""
|
"""基于Oracle的图存储模块"""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""从graphml文件加载图"""
|
"""从graphml文件加载图"""
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
|
|
||||||
#################### insert method ################
|
#################### insert method ################
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||||
"""插入或更新节点"""
|
"""插入或更新节点"""
|
||||||
#print("go into upsert node method")
|
# print("go into upsert node method")
|
||||||
entity_name = node_id
|
entity_name = node_id
|
||||||
entity_type = node_data["entity_type"]
|
entity_type = node_data["entity_type"]
|
||||||
description = node_data["description"]
|
description = node_data["description"]
|
||||||
source_id = node_data["source_id"]
|
source_id = node_data["source_id"]
|
||||||
content = entity_name+description
|
content = entity_name + description
|
||||||
contents = [content]
|
contents = [content]
|
||||||
batches = [
|
batches = [
|
||||||
contents[i: i + self._max_batch_size]
|
contents[i : i + self._max_batch_size]
|
||||||
for i in range(0, len(contents), self._max_batch_size)
|
for i in range(0, len(contents), self._max_batch_size)
|
||||||
]
|
]
|
||||||
embeddings_list = await asyncio.gather(
|
embeddings_list = await asyncio.gather(
|
||||||
@@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
||||||
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
|
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
|
||||||
)
|
)
|
||||||
#print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
|
await self.db.execute(
|
||||||
#self._graph.add_node(node_id, **node_data)
|
merge_sql,
|
||||||
|
[
|
||||||
|
self.db.workspace,
|
||||||
|
entity_name,
|
||||||
|
entity_type,
|
||||||
|
description,
|
||||||
|
source_id,
|
||||||
|
content,
|
||||||
|
content_vector,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# self._graph.add_node(node_id, **node_data)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
):
|
||||||
"""插入或更新边"""
|
"""插入或更新边"""
|
||||||
#print("go into upsert edge method")
|
# print("go into upsert edge method")
|
||||||
source_name = source_node_id
|
source_name = source_node_id
|
||||||
target_name = target_node_id
|
target_name = target_node_id
|
||||||
weight = edge_data["weight"]
|
weight = edge_data["weight"]
|
||||||
keywords = edge_data["keywords"]
|
keywords = edge_data["keywords"]
|
||||||
description = edge_data["description"]
|
description = edge_data["description"]
|
||||||
source_chunk_id = edge_data["source_id"]
|
source_chunk_id = edge_data["source_id"]
|
||||||
content = keywords+source_name+target_name+description
|
content = keywords + source_name + target_name + description
|
||||||
contents = [content]
|
contents = [content]
|
||||||
batches = [
|
batches = [
|
||||||
contents[i: i + self._max_batch_size]
|
contents[i : i + self._max_batch_size]
|
||||||
for i in range(0, len(contents), self._max_batch_size)
|
for i in range(0, len(contents), self._max_batch_size)
|
||||||
]
|
]
|
||||||
embeddings_list = await asyncio.gather(
|
embeddings_list = await asyncio.gather(
|
||||||
@@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
||||||
workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
|
workspace=self.db.workspace,
|
||||||
|
source_name=source_name,
|
||||||
|
target_name=target_name,
|
||||||
|
source_chunk_id=source_chunk_id,
|
||||||
)
|
)
|
||||||
#print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
|
await self.db.execute(
|
||||||
#self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
merge_sql,
|
||||||
|
[
|
||||||
|
self.db.workspace,
|
||||||
|
source_name,
|
||||||
|
target_name,
|
||||||
|
weight,
|
||||||
|
keywords,
|
||||||
|
description,
|
||||||
|
source_chunk_id,
|
||||||
|
content,
|
||||||
|
content_vector,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||||
"""为节点生成向量"""
|
"""为节点生成向量"""
|
||||||
@@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||||
return embeddings, nodes_ids
|
return embeddings, nodes_ids
|
||||||
|
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
"""写入graphhml图文件"""
|
"""写入graphhml图文件"""
|
||||||
logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
|
logger.info(
|
||||||
|
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
||||||
|
)
|
||||||
|
|
||||||
#################### query method #################
|
#################### query method #################
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""根据节点id检查节点是否存在"""
|
"""根据节点id检查节点是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
|
SQL = SQL_TEMPLATES["has_node"].format(
|
||||||
# print(SQL)
|
workspace=self.db.workspace, node_id=node_id
|
||||||
#print(self.db.workspace, node_id)
|
)
|
||||||
|
# print(SQL)
|
||||||
|
# print(self.db.workspace, node_id)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
#print("Node exist!",res)
|
# print("Node exist!",res)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
#print("Node not exist!")
|
# print("Node not exist!")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""根据源和目标节点id检查边是否存在"""
|
"""根据源和目标节点id检查边是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
|
SQL = SQL_TEMPLATES["has_edge"].format(
|
||||||
source_node_id=source_node_id,
|
workspace=self.db.workspace,
|
||||||
target_node_id=target_node_id)
|
source_node_id=source_node_id,
|
||||||
|
target_node_id=target_node_id,
|
||||||
|
)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
#print("Edge exist!",res)
|
# print("Edge exist!",res)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
#print("Edge not exist!")
|
# print("Edge not exist!")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""根据节点id获取节点的度"""
|
"""根据节点id获取节点的度"""
|
||||||
SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
|
SQL = SQL_TEMPLATES["node_degree"].format(
|
||||||
|
workspace=self.db.workspace, node_id=node_id
|
||||||
|
)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
#print("Node degree",res["degree"])
|
# print("Node degree",res["degree"])
|
||||||
return res["degree"]
|
return res["degree"]
|
||||||
else:
|
else:
|
||||||
#print("Edge not exist!")
|
# print("Edge not exist!")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
"""根据源和目标节点id获取边的度"""
|
"""根据源和目标节点id获取边的度"""
|
||||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||||
#print("Edge degree",degree)
|
# print("Edge degree",degree)
|
||||||
return degree
|
return degree
|
||||||
|
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
"""根据节点id获取节点数据"""
|
"""根据节点id获取节点数据"""
|
||||||
SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
|
SQL = SQL_TEMPLATES["get_node"].format(
|
||||||
|
workspace=self.db.workspace, node_id=node_id
|
||||||
|
)
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
#print("Get node!",self.db.workspace, node_id,res)
|
# print("Get node!",self.db.workspace, node_id,res)
|
||||||
return res
|
return res
|
||||||
else:
|
else:
|
||||||
#print("Can't get node!",self.db.workspace, node_id)
|
# print("Can't get node!",self.db.workspace, node_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
"""根据源和目标节点id获取边"""
|
"""根据源和目标节点id获取边"""
|
||||||
SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
|
SQL = SQL_TEMPLATES["get_edge"].format(
|
||||||
source_node_id=source_node_id,
|
workspace=self.db.workspace,
|
||||||
target_node_id=target_node_id)
|
source_node_id=source_node_id,
|
||||||
|
target_node_id=target_node_id,
|
||||||
|
)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL)
|
||||||
if res:
|
if res:
|
||||||
#print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
||||||
return res
|
return res
|
||||||
else:
|
else:
|
||||||
#print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str):
|
async def get_node_edges(self, source_node_id: str):
|
||||||
"""根据节点id获取节点的所有边"""
|
"""根据节点id获取节点的所有边"""
|
||||||
if await self.has_node(source_node_id):
|
if await self.has_node(source_node_id):
|
||||||
SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
|
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
||||||
source_node_id=source_node_id)
|
workspace=self.db.workspace, source_node_id=source_node_id
|
||||||
|
)
|
||||||
res = await self.db.query(sql=SQL, multirows=True)
|
res = await self.db.query(sql=SQL, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = [(i["source_name"],i["target_name"]) for i in res]
|
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||||
#print("Get node edge!",self.db.workspace, source_node_id,data)
|
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -487,12 +531,12 @@ N_T = {
|
|||||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
"entities": "LIGHTRAG_GRAPH_NODES",
|
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||||
"relationships": "LIGHTRAG_GRAPH_EDGES"
|
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
||||||
}
|
}
|
||||||
|
|
||||||
TABLES = {
|
TABLES = {
|
||||||
"LIGHTRAG_DOC_FULL":
|
"LIGHTRAG_DOC_FULL": {
|
||||||
{"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
|
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||||
id varchar(256)PRIMARY KEY,
|
id varchar(256)PRIMARY KEY,
|
||||||
workspace varchar(1024),
|
workspace varchar(1024),
|
||||||
doc_name varchar(1024),
|
doc_name varchar(1024),
|
||||||
@@ -500,61 +544,63 @@ TABLES = {
|
|||||||
meta JSON,
|
meta JSON,
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
updatetime TIMESTAMP DEFAULT NULL
|
updatetime TIMESTAMP DEFAULT NULL
|
||||||
)"""},
|
)"""
|
||||||
|
},
|
||||||
"LIGHTRAG_DOC_CHUNKS":
|
"LIGHTRAG_DOC_CHUNKS": {
|
||||||
{"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||||
id varchar(256) PRIMARY KEY,
|
id varchar(256) PRIMARY KEY,
|
||||||
workspace varchar(1024),
|
workspace varchar(1024),
|
||||||
full_doc_id varchar(256),
|
full_doc_id varchar(256),
|
||||||
chunk_order_index NUMBER,
|
chunk_order_index NUMBER,
|
||||||
tokens NUMBER,
|
tokens NUMBER,
|
||||||
content CLOB,
|
content CLOB,
|
||||||
content_vector VECTOR,
|
content_vector VECTOR,
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
updatetime TIMESTAMP DEFAULT NULL
|
updatetime TIMESTAMP DEFAULT NULL
|
||||||
)"""},
|
)"""
|
||||||
|
},
|
||||||
"LIGHTRAG_GRAPH_NODES":
|
"LIGHTRAG_GRAPH_NODES": {
|
||||||
{"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
||||||
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||||
workspace varchar(1024),
|
workspace varchar(1024),
|
||||||
name varchar(2048),
|
name varchar(2048),
|
||||||
entity_type varchar(1024),
|
entity_type varchar(1024),
|
||||||
description CLOB,
|
description CLOB,
|
||||||
source_chunk_id varchar(256),
|
source_chunk_id varchar(256),
|
||||||
content CLOB,
|
content CLOB,
|
||||||
content_vector VECTOR,
|
content_vector VECTOR,
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
updatetime TIMESTAMP DEFAULT NULL
|
updatetime TIMESTAMP DEFAULT NULL
|
||||||
)"""},
|
)"""
|
||||||
"LIGHTRAG_GRAPH_EDGES":
|
},
|
||||||
{"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
"LIGHTRAG_GRAPH_EDGES": {
|
||||||
|
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
||||||
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||||
workspace varchar(1024),
|
workspace varchar(1024),
|
||||||
source_name varchar(2048),
|
source_name varchar(2048),
|
||||||
target_name varchar(2048),
|
target_name varchar(2048),
|
||||||
weight NUMBER,
|
weight NUMBER,
|
||||||
keywords CLOB,
|
keywords CLOB,
|
||||||
description CLOB,
|
description CLOB,
|
||||||
source_chunk_id varchar(256),
|
source_chunk_id varchar(256),
|
||||||
content CLOB,
|
content CLOB,
|
||||||
content_vector VECTOR,
|
content_vector VECTOR,
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
updatetime TIMESTAMP DEFAULT NULL
|
updatetime TIMESTAMP DEFAULT NULL
|
||||||
)"""},
|
)"""
|
||||||
"LIGHTRAG_LLM_CACHE":
|
},
|
||||||
{"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
|
"LIGHTRAG_LLM_CACHE": {
|
||||||
|
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||||
id varchar(256) PRIMARY KEY,
|
id varchar(256) PRIMARY KEY,
|
||||||
send clob,
|
send clob,
|
||||||
return clob,
|
return clob,
|
||||||
model varchar(1024),
|
model varchar(1024),
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
updatetime TIMESTAMP DEFAULT NULL
|
updatetime TIMESTAMP DEFAULT NULL
|
||||||
)"""},
|
)"""
|
||||||
|
},
|
||||||
"LIGHTRAG_GRAPH":
|
"LIGHTRAG_GRAPH": {
|
||||||
{"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
"ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
||||||
VERTEX TABLES (
|
VERTEX TABLES (
|
||||||
lightrag_graph_nodes KEY (id)
|
lightrag_graph_nodes KEY (id)
|
||||||
LABEL entity
|
LABEL entity
|
||||||
@@ -565,93 +611,67 @@ TABLES = {
|
|||||||
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
||||||
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
||||||
LABEL has_relation
|
LABEL has_relation
|
||||||
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
||||||
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""},
|
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
SQL_TEMPLATES = {
|
SQL_TEMPLATES = {
|
||||||
# SQL for KVStorage
|
# SQL for KVStorage
|
||||||
"get_by_id_full_docs":
|
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
||||||
"select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
||||||
|
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
||||||
"get_by_id_text_chunks":
|
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
||||||
"select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
"filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
||||||
|
|
||||||
"get_by_ids_full_docs":
|
|
||||||
"select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
|
||||||
|
|
||||||
"get_by_ids_text_chunks":
|
|
||||||
"select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
|
||||||
|
|
||||||
"filter_keys":
|
|
||||||
"select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
|
||||||
|
|
||||||
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.id = '{check_id}')
|
ON (a.id = '{check_id}')
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(id,content,workspace) values(:1,:2,:3)
|
INSERT(id,content,workspace) values(:1,:2,:3)
|
||||||
""",
|
""",
|
||||||
|
|
||||||
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.id = '{check_id}')
|
ON (a.id = '{check_id}')
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
||||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
values (:1,:2,:3,:4,:5,:6,:7) """,
|
||||||
|
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"entities":
|
"entities": """SELECT name as entity_name FROM
|
||||||
"""SELECT name as entity_name FROM
|
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||||
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
||||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||||
|
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
||||||
"relationships":
|
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||||
"""SELECT source_name as src_id, target_name as tgt_id FROM
|
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
||||||
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
|
||||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||||
|
"chunks": """SELECT id FROM
|
||||||
"chunks":
|
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||||
"""SELECT id FROM
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
||||||
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||||
|
|
||||||
# SQL for GraphStorage
|
# SQL for GraphStorage
|
||||||
"has_node":
|
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||||
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
|
||||||
MATCH (a)
|
MATCH (a)
|
||||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||||
COLUMNS (a.name))""",
|
COLUMNS (a.name))""",
|
||||||
|
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||||
"has_edge":
|
|
||||||
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
|
||||||
MATCH (a) -[e]-> (b)
|
MATCH (a) -[e]-> (b)
|
||||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||||
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
||||||
COLUMNS (e.source_name,e.target_name) )""",
|
COLUMNS (e.source_name,e.target_name) )""",
|
||||||
|
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
||||||
"node_degree":
|
|
||||||
"""SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||||
AND a.name='{node_id}' or b.name = '{node_id}'
|
AND a.name='{node_id}' or b.name = '{node_id}'
|
||||||
COLUMNS (a.name))""",
|
COLUMNS (a.name))""",
|
||||||
|
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
||||||
"get_node":
|
|
||||||
"""SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)
|
MATCH (a)
|
||||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||||
COLUMNS (a.name)
|
COLUMNS (a.name)
|
||||||
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
||||||
WHERE t2.workspace='{workspace}'""",
|
WHERE t2.workspace='{workspace}'""",
|
||||||
|
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
||||||
"get_edge":
|
|
||||||
"""SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
|
||||||
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
@@ -659,15 +679,12 @@ SQL_TEMPLATES = {
|
|||||||
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
||||||
COLUMNS (e.id,a.name as source_id)
|
COLUMNS (e.id,a.name as source_id)
|
||||||
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
||||||
|
"get_node_edges": """SELECT source_name,target_name
|
||||||
"get_node_edges":
|
|
||||||
"""SELECT source_name,target_name
|
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||||
AND a.name='{source_node_id}'
|
AND a.name='{source_node_id}'
|
||||||
COLUMNS (a.name as source_name,b.name as target_name))""",
|
COLUMNS (a.name as source_name,b.name as target_name))""",
|
||||||
|
|
||||||
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
||||||
@@ -679,5 +696,5 @@ SQL_TEMPLATES = {
|
|||||||
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
||||||
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
|
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
|
||||||
}
|
}
|
||||||
|
@@ -38,15 +38,11 @@ from .storage import (
|
|||||||
JsonKVStorage,
|
JsonKVStorage,
|
||||||
NanoVectorDBStorage,
|
NanoVectorDBStorage,
|
||||||
NetworkXStorage,
|
NetworkXStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .kg.neo4j_impl import Neo4JStorage
|
from .kg.neo4j_impl import Neo4JStorage
|
||||||
|
|
||||||
from .kg.oracle_impl import (
|
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
|
||||||
OracleKVStorage,
|
|
||||||
OracleGraphStorage,
|
|
||||||
OracleVectorDBStorage
|
|
||||||
)
|
|
||||||
|
|
||||||
# future KG integrations
|
# future KG integrations
|
||||||
|
|
||||||
@@ -54,6 +50,7 @@ from .kg.oracle_impl import (
|
|||||||
# GraphStorage as ArangoDBStorage
|
# GraphStorage as ArangoDBStorage
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
try:
|
try:
|
||||||
return asyncio.get_event_loop()
|
return asyncio.get_event_loop()
|
||||||
@@ -72,7 +69,7 @@ class LightRAG:
|
|||||||
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_storage : str = field(default="JsonKVStorage")
|
kv_storage: str = field(default="JsonKVStorage")
|
||||||
vector_storage: str = field(default="NanoVectorDBStorage")
|
vector_storage: str = field(default="NanoVectorDBStorage")
|
||||||
graph_storage: str = field(default="NetworkXStorage")
|
graph_storage: str = field(default="NetworkXStorage")
|
||||||
|
|
||||||
@@ -115,7 +112,7 @@ class LightRAG:
|
|||||||
|
|
||||||
# storage
|
# storage
|
||||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
enable_llm_cache: bool = True
|
enable_llm_cache: bool = True
|
||||||
|
|
||||||
# extension
|
# extension
|
||||||
@@ -134,18 +131,25 @@ class LightRAG:
|
|||||||
|
|
||||||
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
||||||
|
|
||||||
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage]
|
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
||||||
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage]
|
self._get_storage_class()[self.kv_storage]
|
||||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage]
|
)
|
||||||
|
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
|
||||||
|
self.vector_storage
|
||||||
|
]
|
||||||
|
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
|
||||||
|
self.graph_storage
|
||||||
|
]
|
||||||
|
|
||||||
if not os.path.exists(self.working_dir):
|
if not os.path.exists(self.working_dir):
|
||||||
logger.info(f"Creating working directory {self.working_dir}")
|
logger.info(f"Creating working directory {self.working_dir}")
|
||||||
os.makedirs(self.working_dir)
|
os.makedirs(self.working_dir)
|
||||||
|
|
||||||
|
|
||||||
self.llm_response_cache = (
|
self.llm_response_cache = (
|
||||||
self.key_string_value_json_storage_cls(
|
self.key_string_value_json_storage_cls(
|
||||||
namespace="llm_response_cache", global_config=asdict(self),embedding_func=None
|
namespace="llm_response_cache",
|
||||||
|
global_config=asdict(self),
|
||||||
|
embedding_func=None,
|
||||||
)
|
)
|
||||||
if self.enable_llm_cache
|
if self.enable_llm_cache
|
||||||
else None
|
else None
|
||||||
@@ -159,13 +163,19 @@ class LightRAG:
|
|||||||
# add embedding func by walter
|
# add embedding func by walter
|
||||||
####
|
####
|
||||||
self.full_docs = self.key_string_value_json_storage_cls(
|
self.full_docs = self.key_string_value_json_storage_cls(
|
||||||
namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func
|
namespace="full_docs",
|
||||||
|
global_config=asdict(self),
|
||||||
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
self.text_chunks = self.key_string_value_json_storage_cls(
|
self.text_chunks = self.key_string_value_json_storage_cls(
|
||||||
namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func
|
namespace="text_chunks",
|
||||||
|
global_config=asdict(self),
|
||||||
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
||||||
namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func
|
namespace="chunk_entity_relation",
|
||||||
|
global_config=asdict(self),
|
||||||
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
####
|
####
|
||||||
# add embedding func by walter over
|
# add embedding func by walter over
|
||||||
@@ -200,13 +210,11 @@ class LightRAG:
|
|||||||
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
||||||
return {
|
return {
|
||||||
# kv storage
|
# kv storage
|
||||||
"JsonKVStorage":JsonKVStorage,
|
"JsonKVStorage": JsonKVStorage,
|
||||||
"OracleKVStorage":OracleKVStorage,
|
"OracleKVStorage": OracleKVStorage,
|
||||||
|
|
||||||
# vector storage
|
# vector storage
|
||||||
"NanoVectorDBStorage":NanoVectorDBStorage,
|
"NanoVectorDBStorage": NanoVectorDBStorage,
|
||||||
"OracleVectorDBStorage":OracleVectorDBStorage,
|
"OracleVectorDBStorage": OracleVectorDBStorage,
|
||||||
|
|
||||||
# graph storage
|
# graph storage
|
||||||
"NetworkXStorage": NetworkXStorage,
|
"NetworkXStorage": NetworkXStorage,
|
||||||
"Neo4JStorage": Neo4JStorage,
|
"Neo4JStorage": Neo4JStorage,
|
||||||
|
@@ -16,7 +16,7 @@ from .utils import (
|
|||||||
split_string_by_multi_markers,
|
split_string_by_multi_markers,
|
||||||
truncate_list_by_token_size,
|
truncate_list_by_token_size,
|
||||||
process_combine_contexts,
|
process_combine_contexts,
|
||||||
locate_json_string_body_from_string
|
locate_json_string_body_from_string,
|
||||||
)
|
)
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
|
@@ -1,22 +1,22 @@
|
|||||||
accelerate
|
accelerate
|
||||||
|
aioboto3
|
||||||
aiohttp
|
aiohttp
|
||||||
|
|
||||||
|
# database packages
|
||||||
|
graspologic
|
||||||
|
hnswlib
|
||||||
|
nano-vectordb
|
||||||
|
neo4j
|
||||||
|
networkx
|
||||||
|
ollama
|
||||||
|
openai
|
||||||
|
oracledb
|
||||||
pyvis
|
pyvis
|
||||||
tenacity
|
tenacity
|
||||||
xxhash
|
|
||||||
# lmdeploy[all]
|
# lmdeploy[all]
|
||||||
|
|
||||||
# LLM packages
|
# LLM packages
|
||||||
tiktoken
|
tiktoken
|
||||||
torch
|
torch
|
||||||
transformers
|
transformers
|
||||||
aioboto3
|
xxhash
|
||||||
ollama
|
|
||||||
openai
|
|
||||||
|
|
||||||
# database packages
|
|
||||||
graspologic
|
|
||||||
hnswlib
|
|
||||||
networkx
|
|
||||||
oracledb
|
|
||||||
nano-vectordb
|
|
||||||
neo4j
|
|
||||||
|
Reference in New Issue
Block a user