From d904f2abd55063d639202704a327d1f2458ff4ab Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Tue, 12 Nov 2024 15:38:16 +0800
Subject: [PATCH 01/22] Update lightrag_api_oracle_demo..py
---
examples/lightrag_api_oracle_demo..py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py
index 3bfae452..9b4e2741 100644
--- a/examples/lightrag_api_oracle_demo..py
+++ b/examples/lightrag_api_oracle_demo..py
@@ -149,13 +149,13 @@ class Response(BaseModel):
# API routes
-rag = None # 定义为全局对象
+rag = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
- rag = await init() # 在应用启动时初始化 `rag`
+ rag = await init()
print("done!")
yield
From 83f8a5139c8f076bcb893eb91ac4f0c660364fb6 Mon Sep 17 00:00:00 2001
From: Rick Battle
-
+
From 186cd34a033d13c13317cba356e0076cffc113fe Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Thu, 14 Nov 2024 16:21:20 +0800
Subject: [PATCH 06/22] Update Discord Link
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index be6c1bb5..b62f01a1 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
- [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
- [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
-- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
+- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/E4HgTnck)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
From 5661d76860436f7bf5aef2e50d9ee4a59660146c Mon Sep 17 00:00:00 2001
From: Richard <164130786@qq.com>
Date: Fri, 15 Nov 2024 13:11:43 +0800
Subject: [PATCH 07/22] fix neo4j bug
---
lightrag/kg/neo4j_impl.py | 1 +
lightrag/lightrag.py | 3 +--
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index e6b33a9b..32bfbe2e 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -214,6 +214,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
+ neo4jExceptions.ClientError
)
),
)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 67337098..ce27e76d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -174,8 +174,7 @@ class LightRAG:
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation",
- global_config=asdict(self),
- embedding_func=self.embedding_func,
+ global_config=asdict(self)
)
####
# add embedding func by walter over
From 8e16f0815ce75a81797fb45d2f755204af9cd638 Mon Sep 17 00:00:00 2001
From: tmuife <43266626@qq.com>
Date: Mon, 18 Nov 2024 10:00:06 +0800
Subject: [PATCH 08/22] change the type of binding parameters in Oracle23AI
---
lightrag/kg/oracle_impl.py | 320 +++++++++++++++++++++----------------
1 file changed, 178 insertions(+), 142 deletions(-)
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 96a9e795..fd8bf536 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,16 +114,17 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
- async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+ async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
with connection.cursor() as cursor:
try:
- await cursor.execute(sql)
+ await cursor.execute(sql, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(sql)
+ print(params)
raise
columns = [column[0].lower() for column in cursor.description]
if multirows:
@@ -140,7 +141,7 @@ class OracleDB:
data = None
return data
- async def execute(self, sql: str, data: list = None):
+ async def execute(self, sql: str, data: list | dict = None):
# logger.info("go into OracleDB execute method")
try:
async with self.pool.acquire() as connection:
@@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
- SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
- workspace=self.db.workspace, id=id
- )
+ SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+ params = {"workspace":self.db.workspace, "id":id}
# print("get_by_id:"+SQL)
- res = await self.db.query(SQL)
+ res = await self.db.query(SQL,params)
if res:
data = res # {"data":res}
# print (data)
@@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
- workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
- )
- # print("get_by_ids:"+SQL)
- res = await self.db.query(SQL, multirows=True)
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
+ params = {"workspace":self.db.workspace}
+ #print("get_by_ids:"+SQL)
+ #print(params)
+ res = await self.db.query(SQL,params, multirows=True)
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
- SQL = SQL_TEMPLATES["filter_keys"].format(
- table_name=N_T[self.namespace],
- workspace=self.db.workspace,
- ids=",".join([f"'{k}'" for k in keys]),
- )
- res = await self.db.query(SQL, multirows=True)
+ SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
+ ids=",".join([f"'{id}'" for id in keys]))
+ params = {"workspace":self.db.workspace}
+ try:
+ await self.db.query(SQL, params)
+ except Exception as e:
+ logger.error(f"Oracle database error: {e}")
+ print(SQL)
+ print(params)
+ res = await self.db.query(SQL, params,multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
@@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
d["__vector__"] = embeddings[i]
# print(list_data)
for item in list_data:
- merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
-
- values = [
- item["__id__"],
- item["content"],
- self.db.workspace,
- item["tokens"],
- item["chunk_order_index"],
- item["full_doc_id"],
- item["__vector__"],
- ]
+ merge_sql = SQL_TEMPLATES["merge_chunk"]
+ data = {"check_id":item["__id__"],
+ "id":item["__id__"],
+ "content":item["content"],
+ "workspace":self.db.workspace,
+ "tokens":item["tokens"],
+ "chunk_order_index":item["chunk_order_index"],
+ "full_doc_id":item["full_doc_id"],
+ "content_vector":item["__vector__"]
+ }
# print(merge_sql)
- await self.db.execute(merge_sql, values)
+ await self.db.execute(merge_sql, data)
if self.namespace == "full_docs":
for k, v in self._data.items():
# values.clear()
- merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
- check_id=k,
- )
- values = [k, self._data[k]["content"], self.db.workspace]
+ merge_sql = SQL_TEMPLATES["merge_doc_full"]
+ data = {
+ "check_id":k,
+ "id":k,
+ "content":v["content"],
+ "workspace":self.db.workspace
+ }
# print(merge_sql)
- await self.db.execute(merge_sql, values)
+ await self.db.execute(merge_sql, data)
return left_data
async def index_done_callback(self):
@@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
# 转换精度
dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0]
- embedding_string = ", ".join(map(str, embedding.tolist()))
+ embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
- SQL = SQL_TEMPLATES[self.namespace].format(
- embedding_string=embedding_string,
- dimension=dimension,
- dtype=dtype,
- workspace=self.db.workspace,
- top_k=top_k,
- better_than_threshold=self.cosine_better_than_threshold,
- )
+ SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
+ params = {
+ "embedding_string": embedding_string,
+ "workspace": self.db.workspace,
+ "top_k": top_k,
+ "better_than_threshold": self.cosine_better_than_threshold,
+ }
# print(SQL)
- results = await self.db.query(SQL, multirows=True)
+ results = await self.db.query(SQL,params=params, multirows=True)
# print("vector search result:",results)
return results
@@ -328,6 +333,8 @@ class OracleGraphStorage(BaseGraphStorage):
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
+ logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
+
content = entity_name + description
contents = [content]
batches = [
@@ -339,22 +346,18 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_node"].format(
- workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
- )
+ merge_sql = SQL_TEMPLATES["merge_node"]
+ data = {
+ "workspace":self.db.workspace,
+ "name":entity_name,
+ "entity_type":entity_type,
+ "description":description,
+ "source_chunk_id":source_id,
+ "content":content,
+ "content_vector":content_vector
+ }
# print(merge_sql)
- await self.db.execute(
- merge_sql,
- [
- self.db.workspace,
- entity_name,
- entity_type,
- description,
- source_id,
- content,
- content_vector,
- ],
- )
+ await self.db.execute(merge_sql,data)
# self._graph.add_node(node_id, **node_data)
async def upsert_edge(
@@ -368,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
+ logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
+
content = keywords + source_name + target_name + description
contents = [content]
batches = [
@@ -379,27 +384,20 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_edge"].format(
- workspace=self.db.workspace,
- source_name=source_name,
- target_name=target_name,
- source_chunk_id=source_chunk_id,
- )
+ merge_sql = SQL_TEMPLATES["merge_edge"]
+ data = {
+ "workspace":self.db.workspace,
+ "source_name":source_name,
+ "target_name":target_name,
+ "weight":weight,
+ "keywords":keywords,
+ "description":description,
+ "source_chunk_id":source_chunk_id,
+ "content":content,
+ "content_vector":content_vector
+ }
# print(merge_sql)
- await self.db.execute(
- merge_sql,
- [
- self.db.workspace,
- source_name,
- target_name,
- weight,
- keywords,
- description,
- source_chunk_id,
- content,
- content_vector,
- ],
- )
+ await self.db.execute(merge_sql,data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -429,12 +427,14 @@ class OracleGraphStorage(BaseGraphStorage):
#################### query method #################
async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
- SQL = SQL_TEMPLATES["has_node"].format(
- workspace=self.db.workspace, node_id=node_id
- )
+ SQL = SQL_TEMPLATES["has_node"]
+ params = {
+ "workspace":self.db.workspace,
+ "node_id":node_id
+ }
# print(SQL)
# print(self.db.workspace, node_id)
- res = await self.db.query(SQL)
+ res = await self.db.query(SQL,params)
if res:
# print("Node exist!",res)
return True
@@ -444,13 +444,14 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
- SQL = SQL_TEMPLATES["has_edge"].format(
- workspace=self.db.workspace,
- source_node_id=source_node_id,
- target_node_id=target_node_id,
- )
+ SQL = SQL_TEMPLATES["has_edge"]
+ params = {
+ "workspace":self.db.workspace,
+ "source_node_id":source_node_id,
+ "target_node_id":target_node_id
+ }
# print(SQL)
- res = await self.db.query(SQL)
+ res = await self.db.query(SQL,params)
if res:
# print("Edge exist!",res)
return True
@@ -460,11 +461,13 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
- SQL = SQL_TEMPLATES["node_degree"].format(
- workspace=self.db.workspace, node_id=node_id
- )
+ SQL = SQL_TEMPLATES["node_degree"]
+ params = {
+ "workspace":self.db.workspace,
+ "node_id":node_id
+ }
# print(SQL)
- res = await self.db.query(SQL)
+ res = await self.db.query(SQL,params)
if res:
# print("Node degree",res["degree"])
return res["degree"]
@@ -480,12 +483,14 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据"""
- SQL = SQL_TEMPLATES["get_node"].format(
- workspace=self.db.workspace, node_id=node_id
- )
+ SQL = SQL_TEMPLATES["get_node"]
+ params = {
+ "workspace":self.db.workspace,
+ "node_id":node_id
+ }
# print(self.db.workspace, node_id)
# print(SQL)
- res = await self.db.query(SQL)
+ res = await self.db.query(SQL,params)
if res:
# print("Get node!",self.db.workspace, node_id,res)
return res
@@ -497,12 +502,13 @@ class OracleGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""根据源和目标节点id获取边"""
- SQL = SQL_TEMPLATES["get_edge"].format(
- workspace=self.db.workspace,
- source_node_id=source_node_id,
- target_node_id=target_node_id,
- )
- res = await self.db.query(SQL)
+ SQL = SQL_TEMPLATES["get_edge"]
+ params = {
+ "workspace":self.db.workspace,
+ "source_node_id":source_node_id,
+ "target_node_id":target_node_id
+ }
+ res = await self.db.query(SQL,params)
if res:
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
return res
@@ -513,10 +519,12 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node_edges(self, source_node_id: str):
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id):
- SQL = SQL_TEMPLATES["get_node_edges"].format(
- workspace=self.db.workspace, source_node_id=source_node_id
- )
- res = await self.db.query(sql=SQL, multirows=True)
+ SQL = SQL_TEMPLATES["get_node_edges"]
+ params = {
+ "workspace":self.db.workspace,
+ "source_node_id":source_node_id
+ }
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
# print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -524,8 +532,22 @@ class OracleGraphStorage(BaseGraphStorage):
else:
# print("Node Edge not exist!",self.db.workspace, source_node_id)
return []
+
+ async def get_all_nodes(self, limit: int):
+ """查询所有节点"""
+ SQL = SQL_TEMPLATES["get_all_nodes"]
+ params = {"workspace":self.db.workspace, "limit":str(limit)}
+ res = await self.db.query(sql=SQL,params=params, multirows=True)
+ if res:
+ return res
-
+ async def get_all_edges(self, limit: int):
+ """查询所有边"""
+ SQL = SQL_TEMPLATES["get_all_edges"]
+ params = {"workspace":self.db.workspace, "limit":str(limit)}
+ res = await self.db.query(sql=SQL,params=params, multirows=True)
+ if res:
+ return res
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -619,82 +641,96 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
- "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
- "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
- "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
- "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
- "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
+ "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
+ "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
+ "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
+ "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+ "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
- ON (a.id = '{check_id}')
+ ON (a.id = :check_id)
WHEN NOT MATCHED THEN
- INSERT(id,content,workspace) values(:1,:2,:3)
+ INSERT(id,content,workspace) values(:id,:content,:workspace)
""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
USING DUAL
- ON (a.id = '{check_id}')
+ ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
- values (:1,:2,:3,:4,:5,:6,:7) """,
+ values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM
- (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
- FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
- WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+ (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+ FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
+ WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
- (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
- FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
- WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+ (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+ FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
+ WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
"chunks": """SELECT id FROM
- (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
- FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
- WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+ (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+ FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
+ WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
# SQL for GraphStorage
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
MATCH (a)
- WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+ WHERE a.workspace=:workspace AND a.name=:node_id
COLUMNS (a.name))""",
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
MATCH (a) -[e]-> (b)
- WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
- AND a.name='{source_node_id}' AND b.name='{target_node_id}'
+ WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+ AND a.name=:source_node_id AND b.name=:target_node_id
COLUMNS (e.source_name,e.target_name) )""",
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
- WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
- AND a.name='{node_id}' or b.name = '{node_id}'
+ WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+ AND a.name=:node_id or b.name = :node_id
COLUMNS (a.name))""",
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)
- WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+ WHERE a.workspace=:workspace AND a.name=:node_id
COLUMNS (a.name)
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
- WHERE t2.workspace='{workspace}'""",
+ WHERE t2.workspace=:workspace""",
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
- WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
- AND a.name='{source_node_id}' and b.name = '{target_node_id}'
+ WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+ AND a.name=:source_node_id and b.name = :target_node_id
COLUMNS (e.id,a.name as source_id)
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
"get_node_edges": """SELECT source_name,target_name
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
- WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
- AND a.name='{source_node_id}'
+ WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+ AND a.name=:source_node_id
COLUMNS (a.name as source_name,b.name as target_name))""",
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
USING DUAL
- ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
+ ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
WHEN NOT MATCHED THEN
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
- values (:1,:2,:3,:4,:5,:6,:7) """,
+ values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
USING DUAL
- ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
+ ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
- values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
+ values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+ "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
+ FROM LIGHTRAG_GRAPH_NODES t1
+ LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+ WHERE t1.workspace=:workspace
+ order by t1.CREATETIME DESC
+ fetch first :limit rows only
+ """,
+ "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
+ t1.weight,t1.DESCRIPTION,t2.content
+ FROM LIGHTRAG_GRAPH_EDGES t1
+ LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+ WHERE t1.workspace=:workspace
+ order by t1.CREATETIME DESC
+ fetch first :limit rows only"""
}
From 89d7967349fe30c1bf682935ce95d54082b7d9fe Mon Sep 17 00:00:00 2001
From: tmuife <43266626@qq.com>
Date: Mon, 18 Nov 2024 13:52:49 +0800
Subject: [PATCH 09/22] use pre-commit reformat
---
lightrag/kg/oracle_impl.py | 176 ++++++++++++++++++-------------------
1 file changed, 87 insertions(+), 89 deletions(-)
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index fd8bf536..b46d36d8 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,7 +114,9 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
- async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
+ async def query(
+ self, sql: str, params: dict = None, multirows: bool = False
+ ) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
@@ -174,9 +176,9 @@ class OracleKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
- params = {"workspace":self.db.workspace, "id":id}
+ params = {"workspace": self.db.workspace, "id": id}
# print("get_by_id:"+SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL, params)
if res:
data = res # {"data":res}
# print (data)
@@ -187,11 +189,13 @@ class OracleKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
- params = {"workspace":self.db.workspace}
- #print("get_by_ids:"+SQL)
- #print(params)
- res = await self.db.query(SQL,params, multirows=True)
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+ ids=",".join([f"'{id}'" for id in ids])
+ )
+ params = {"workspace": self.db.workspace}
+ # print("get_by_ids:"+SQL)
+ # print(params)
+ res = await self.db.query(SQL, params, multirows=True)
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -201,16 +205,17 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
- SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
- ids=",".join([f"'{id}'" for id in keys]))
- params = {"workspace":self.db.workspace}
+ SQL = SQL_TEMPLATES["filter_keys"].format(
+ table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+ )
+ params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
- res = await self.db.query(SQL, params,multirows=True)
+ res = await self.db.query(SQL, params, multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
@@ -248,15 +253,16 @@ class OracleKVStorage(BaseKVStorage):
# print(list_data)
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"]
- data = {"check_id":item["__id__"],
- "id":item["__id__"],
- "content":item["content"],
- "workspace":self.db.workspace,
- "tokens":item["tokens"],
- "chunk_order_index":item["chunk_order_index"],
- "full_doc_id":item["full_doc_id"],
- "content_vector":item["__vector__"]
- }
+ data = {
+ "check_id": item["__id__"],
+ "id": item["__id__"],
+ "content": item["content"],
+ "workspace": self.db.workspace,
+ "tokens": item["tokens"],
+ "chunk_order_index": item["chunk_order_index"],
+ "full_doc_id": item["full_doc_id"],
+ "content_vector": item["__vector__"],
+ }
# print(merge_sql)
await self.db.execute(merge_sql, data)
@@ -265,11 +271,11 @@ class OracleKVStorage(BaseKVStorage):
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"]
data = {
- "check_id":k,
- "id":k,
- "content":v["content"],
- "workspace":self.db.workspace
- }
+ "check_id": k,
+ "id": k,
+ "content": v["content"],
+ "workspace": self.db.workspace,
+ }
# print(merge_sql)
await self.db.execute(merge_sql, data)
return left_data
@@ -301,7 +307,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
# 转换精度
dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0]
- embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
+ embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
params = {
@@ -309,9 +315,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
"workspace": self.db.workspace,
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
- }
+ }
# print(SQL)
- results = await self.db.query(SQL,params=params, multirows=True)
+ results = await self.db.query(SQL, params=params, multirows=True)
# print("vector search result:",results)
return results
@@ -346,18 +352,18 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_node"]
+ merge_sql = SQL_TEMPLATES["merge_node"]
data = {
- "workspace":self.db.workspace,
- "name":entity_name,
- "entity_type":entity_type,
- "description":description,
- "source_chunk_id":source_id,
- "content":content,
- "content_vector":content_vector
- }
+ "workspace": self.db.workspace,
+ "name": entity_name,
+ "entity_type": entity_type,
+ "description": description,
+ "source_chunk_id": source_id,
+ "content": content,
+ "content_vector": content_vector,
+ }
# print(merge_sql)
- await self.db.execute(merge_sql,data)
+ await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data)
async def upsert_edge(
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
- logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
+ logger.debug(
+ f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
+ )
content = keywords + source_name + target_name + description
contents = [content]
@@ -384,20 +392,20 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_edge"]
+ merge_sql = SQL_TEMPLATES["merge_edge"]
data = {
- "workspace":self.db.workspace,
- "source_name":source_name,
- "target_name":target_name,
- "weight":weight,
- "keywords":keywords,
- "description":description,
- "source_chunk_id":source_chunk_id,
- "content":content,
- "content_vector":content_vector
- }
+ "workspace": self.db.workspace,
+ "source_name": source_name,
+ "target_name": target_name,
+ "weight": weight,
+ "keywords": keywords,
+ "description": description,
+ "source_chunk_id": source_chunk_id,
+ "content": content,
+ "content_vector": content_vector,
+ }
# print(merge_sql)
- await self.db.execute(merge_sql,data)
+ await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -428,13 +436,10 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
# print(self.db.workspace, node_id)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL, params)
if res:
# print("Node exist!",res)
return True
@@ -446,12 +451,12 @@ class OracleGraphStorage(BaseGraphStorage):
"""根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"]
params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id,
- "target_node_id":target_node_id
- }
+ "workspace": self.db.workspace,
+ "source_node_id": source_node_id,
+ "target_node_id": target_node_id,
+ }
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL, params)
if res:
# print("Edge exist!",res)
return True
@@ -462,12 +467,9 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL, params)
if res:
# print("Node degree",res["degree"])
return res["degree"]
@@ -484,13 +486,10 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ params = {"workspace": self.db.workspace, "node_id": node_id}
# print(self.db.workspace, node_id)
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL, params)
if res:
# print("Get node!",self.db.workspace, node_id,res)
return res
@@ -504,11 +503,11 @@ class OracleGraphStorage(BaseGraphStorage):
"""根据源和目标节点id获取边"""
SQL = SQL_TEMPLATES["get_edge"]
params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id,
- "target_node_id":target_node_id
- }
- res = await self.db.query(SQL,params)
+ "workspace": self.db.workspace,
+ "source_node_id": source_node_id,
+ "target_node_id": target_node_id,
+ }
+ res = await self.db.query(SQL, params)
if res:
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
return res
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"]
- params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id
- }
+ params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
@@ -532,22 +528,24 @@ class OracleGraphStorage(BaseGraphStorage):
else:
# print("Node Edge not exist!",self.db.workspace, source_node_id)
return []
-
+
async def get_all_nodes(self, limit: int):
"""查询所有节点"""
SQL = SQL_TEMPLATES["get_all_nodes"]
- params = {"workspace":self.db.workspace, "limit":str(limit)}
- res = await self.db.query(sql=SQL,params=params, multirows=True)
+ params = {"workspace": self.db.workspace, "limit": str(limit)}
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
async def get_all_edges(self, limit: int):
"""查询所有边"""
SQL = SQL_TEMPLATES["get_all_edges"]
- params = {"workspace":self.db.workspace, "limit":str(limit)}
- res = await self.db.query(sql=SQL,params=params, multirows=True)
+ params = {"workspace": self.db.workspace, "limit": str(limit)}
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
+
+
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -719,18 +717,18 @@ SQL_TEMPLATES = {
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
- "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
+ "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_NODES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only
""",
- "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
+ "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
t1.weight,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_EDGES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
- fetch first :limit rows only"""
+ fetch first :limit rows only""",
}
From e203aad3de6fe36c61a6cba73911e7df8311feff Mon Sep 17 00:00:00 2001
From: WinstonCHEN1 <1281838223@qq.com>
Date: Mon, 18 Nov 2024 14:24:04 -0800
Subject: [PATCH 10/22] fix:error working directory name in Step_1.py
---
reproduce/Step_1.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py
index 43c44056..e318c145 100644
--- a/reproduce/Step_1.py
+++ b/reproduce/Step_1.py
@@ -24,7 +24,7 @@ def insert_text(rag, file_path):
cls = "agriculture"
-WORKING_DIR = "../{cls}"
+WORKING_DIR = f"../{cls}"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
From c9becdf5f40e6f3b291b9a3d50a9624d677e5d65 Mon Sep 17 00:00:00 2001
From: luoyifan <1625370020@qq.com>
Date: Tue, 19 Nov 2024 14:02:38 +0800
Subject: [PATCH 11/22] A more robust approach for result to json.
---
lightrag/operate.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/lightrag/operate.py b/lightrag/operate.py
index b11e14fe..eb600c4b 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -418,7 +418,7 @@ async def local_query(
.replace("model", "")
.strip()
)
- result = "{" + result.split("{")[1].split("}")[0] + "}"
+ result = "{" + result.split("{")[-1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
@@ -691,7 +691,7 @@ async def global_query(
.replace("model", "")
.strip()
)
- result = "{" + result.split("{")[1].split("}")[0] + "}"
+ result = "{" + result.split("{")[-1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
@@ -940,7 +940,7 @@ async def hybrid_query(
.replace("model", "")
.strip()
)
- result = "{" + result.split("{")[1].split("}")[0] + "}"
+ result = "{" + result.split("{")[-1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
From bcaaaad9598cf89b248af84d228b68a7220d1096 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Tue, 19 Nov 2024 16:52:26 +0800
Subject: [PATCH 12/22] Update
---
README.md | 8 ++++++--
lightrag/kg/neo4j_impl.py | 2 +-
lightrag/lightrag.py | 3 +--
3 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index b62f01a1..d7e3d418 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,9 @@
+
+
@@ -35,8 +37,10 @@ This repository hosts the code of LightRAG. The structure of this code is based
## Algorithm Flowchart
-
-
+
+*Figure 1: LightRAG Indexing Flowchart*
+
+*Figure 2: LightRAG Retrieval and Querying Flowchart*
## Install
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 32bfbe2e..f9fcb46d 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -214,7 +214,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
- neo4jExceptions.ClientError
+ neo4jExceptions.ClientError,
)
),
)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index ce27e76d..7fafadcf 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -173,8 +173,7 @@ class LightRAG:
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
- namespace="chunk_entity_relation",
- global_config=asdict(self)
+ namespace="chunk_entity_relation", global_config=asdict(self)
)
####
# add embedding func by walter over
From 9d871fbc71aff5c270b255ae742e3f66578c1789 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Tue, 19 Nov 2024 16:54:14 +0800
Subject: [PATCH 13/22] Update README.md
---
README.md | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index d7e3d418..36e1b2a9 100644
--- a/README.md
+++ b/README.md
@@ -9,8 +9,6 @@
-
-