Use Oracle bind variables to avoid errors

jin
2024-11-15 12:57:01 +08:00
parent 41599897fb
commit 662303f605
4 changed files with 193 additions and 146 deletions
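
The core of the commit: SQL statements no longer have values interpolated into them with str.format(); values travel as named bind variables instead, so content containing quotes can no longer break the statement (the error of the title) and Oracle can reuse the parsed cursor. A minimal sketch of the pattern with the python-oracledb driver, using placeholder credentials and values:

import oracledb

# Placeholder connection; the real code uses a pooled async connection.
conn = oracledb.connect(user="demo", password="demo", dsn="localhost/FREEPDB1")
with conn.cursor() as cursor:
    # Before: values baked into the SQL text -- breaks on content like "it's"
    # cursor.execute(f"select ID from LIGHTRAG_DOC_FULL where ID='{doc_id}'")
    # After: named binds; the driver handles quoting and type conversion
    sql = "select ID from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id"
    cursor.execute(sql, {"workspace": "default", "id": "doc-1"})
    print(cursor.fetchone())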

View File

@@ -17,6 +17,7 @@ T = TypeVar("T")
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
+    only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
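
With the new only_need_prompt flag, a caller can retrieve the fully assembled system prompt (context already injected) instead of a model answer, mirroring the existing only_need_context switch. A hedged usage sketch, assuming LightRAG's usual query entry point:

from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./rag_storage")  # placeholder setup
# Returns sys_prompt without calling the LLM (see the operate.py hunks below)
prompt = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="local", only_need_prompt=True),
)
print(prompt)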

View File

@@ -114,16 +114,17 @@ class OracleDB:
         logger.info("Finished check all tables in Oracle database")

-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
                 try:
-                    await cursor.execute(sql)
+                    await cursor.execute(sql, params)
                 except Exception as e:
                     logger.error(f"Oracle database error: {e}")
                     print(sql)
+                    print(params)
                     raise
                 columns = [column[0].lower() for column in cursor.description]
                 if multirows:
@@ -140,7 +141,7 @@ class OracleDB:
                         data = None
         return data

-    async def execute(self, sql: str, data: list = None):
+    async def execute(self, sql: str, data: list | dict = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
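
Both helpers now take the statement and its values separately; a sketch of how call sites elsewhere in the file use them after this change (identifiers come from the diff, values are placeholders):

async def demo(db):  # db: an OracleDB instance from this module
    sql = "select ID from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id"
    row = await db.query(sql, {"workspace": "default", "id": "doc-1"})
    # execute() accepts a dict whose keys match the named placeholders
    await db.execute(SQL_TEMPLATES["merge_doc_full"], {
        "check_id": "doc-1",
        "id": "doc-1",
        "content": "content with 'quotes' is now safe",
        "workspace": "default",
    })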
@@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
-            workspace=self.db.workspace, id=id
-        )
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             data = res  # {"data":res}
             # print(data)
@@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by a list of ids."""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
-        )
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
+        params = {"workspace": self.db.workspace}
         # print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL, multirows=True)
+        # print(params)
+        res = await self.db.query(SQL, params, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicated content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(
-            table_name=N_T[self.namespace],
-            workspace=self.db.workspace,
-            ids=",".join([f"'{k}'" for k in keys]),
-        )
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
+        )
+        params = {"workspace": self.db.workspace}
+        try:
+            await self.db.query(SQL, params)
+        except Exception as e:
+            logger.error(f"Oracle database error: {e}")
+            print(SQL)
+            print(params)
+        res = await self.db.query(SQL, params, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
d["__vector__"] = embeddings[i] d["__vector__"] = embeddings[i]
# print(list_data) # print(list_data)
for item in list_data: for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {"check_id":item["__id__"],
values = [ "id":item["__id__"],
item["__id__"], "content":item["content"],
item["content"], "workspace":self.db.workspace,
self.db.workspace, "tokens":item["tokens"],
item["tokens"], "chunk_order_index":item["chunk_order_index"],
item["chunk_order_index"], "full_doc_id":item["full_doc_id"],
item["full_doc_id"], "content_vector":item["__vector__"]
item["__vector__"], }
]
# print(merge_sql) # print(merge_sql)
await self.db.execute(merge_sql, values) await self.db.execute(merge_sql, data)
if self.namespace == "full_docs": if self.namespace == "full_docs":
for k, v in self._data.items(): for k, v in self._data.items():
# values.clear() # values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"].format( merge_sql = SQL_TEMPLATES["merge_doc_full"]
check_id=k, data = {
) "check_id":k,
values = [k, self._data[k]["content"], self.db.workspace] "id":k,
"content":v["content"],
"workspace":self.db.workspace
}
# print(merge_sql) # print(merge_sql)
await self.db.execute(merge_sql, values) await self.db.execute(merge_sql, data)
return left_data return left_data
async def index_done_callback(self): async def index_done_callback(self):
@@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = ", ".join(map(str, embedding.tolist()))
-        SQL = SQL_TEMPLATES[self.namespace].format(
-            embedding_string=embedding_string,
-            dimension=dimension,
-            dtype=dtype,
-            workspace=self.db.workspace,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
+        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
+        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
+        params = {
+            "embedding_string": embedding_string,
+            "workspace": self.db.workspace,
+            "top_k": top_k,
+            "better_than_threshold": self.cosine_better_than_threshold,
+        }
         # print(SQL)
-        results = await self.db.query(SQL, multirows=True)
+        results = await self.db.query(SQL, params=params, multirows=True)
         # print("vector search result:",results)
         return results
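
Note the split here: the vector itself is bound as a single string ("[v1, v2, ...]") that Oracle's vector() constructor parses, while dimension and dtype remain str.format() substitutions because they are part of the type literal rather than values. A toy illustration of building those params:

import numpy as np

embedding = np.random.rand(4).astype(np.float32)  # toy vector; real code uses the embedding model
dtype = str(embedding.dtype).upper()              # "FLOAT32"
dimension = embedding.shape[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
sql = SQL_TEMPLATES["chunks"].format(dimension=dimension, dtype=dtype)
params = {
    "embedding_string": embedding_string,
    "workspace": "default",
    "top_k": 10,
    "better_than_threshold": 0.2,
}
# results = await db.query(sql, params=params, multirows=True)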
@@ -339,22 +344,18 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_node"].format(
-            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
-        )
-        # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                entity_name,
-                entity_type,
-                description,
-                source_id,
-                content,
-                content_vector,
-            ],
-        )
+        merge_sql = SQL_TEMPLATES["merge_node"]
+        data = {
+            "workspace": self.db.workspace,
+            "name": entity_name,
+            "entity_type": entity_type,
+            "description": description,
+            "source_chunk_id": source_id,
+            "content": content,
+            "content_vector": content_vector,
+        }
+        # print(merge_sql)
+        await self.db.execute(merge_sql, data)
         # self._graph.add_node(node_id, **node_data)

     async def upsert_edge(
@@ -379,27 +380,20 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_edge"].format(
-            workspace=self.db.workspace,
-            source_name=source_name,
-            target_name=target_name,
-            source_chunk_id=source_chunk_id,
-        )
-        # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                source_name,
-                target_name,
-                weight,
-                keywords,
-                description,
-                source_chunk_id,
-                content,
-                content_vector,
-            ],
-        )
+        merge_sql = SQL_TEMPLATES["merge_edge"]
+        data = {
+            "workspace": self.db.workspace,
+            "source_name": source_name,
+            "target_name": target_name,
+            "weight": weight,
+            "keywords": keywords,
+            "description": description,
+            "source_chunk_id": source_chunk_id,
+            "content": content,
+            "content_vector": content_vector,
+        }
+        # print(merge_sql)
+        await self.db.execute(merge_sql, data)
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -429,12 +423,14 @@ class OracleGraphStorage(BaseGraphStorage):
     #################### query method #################

     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by its id."""
-        SQL = SQL_TEMPLATES["has_node"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["has_node"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id,
+        }
         # print(SQL)
         # print(self.db.workspace, node_id)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Node exist!",res)
             return True
@@ -444,13 +440,14 @@ class OracleGraphStorage(BaseGraphStorage):
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by its source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"].format(
-            workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id,
-        )
+        SQL = SQL_TEMPLATES["has_edge"]
+        params = {
+            "workspace": self.db.workspace,
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id,
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Edge exist!",res)
             return True
@@ -460,11 +457,13 @@ class OracleGraphStorage(BaseGraphStorage):
     async def node_degree(self, node_id: str) -> int:
         """Get the degree of a node by its id."""
-        SQL = SQL_TEMPLATES["node_degree"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["node_degree"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id,
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Node degree",res["degree"])
             return res["degree"]
@@ -480,12 +479,14 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by its id."""
-        SQL = SQL_TEMPLATES["get_node"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["get_node"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id,
+        }
         # print(self.db.workspace, node_id)
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Get node!",self.db.workspace, node_id,res)
             return res
@@ -497,12 +498,13 @@ class OracleGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by its source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"].format(
-            workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id,
-        )
-        res = await self.db.query(SQL)
+        SQL = SQL_TEMPLATES["get_edge"]
+        params = {
+            "workspace": self.db.workspace,
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id,
+        }
+        res = await self.db.query(SQL, params)
         if res:
             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
@@ -513,10 +515,12 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by its id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"].format(
-                workspace=self.db.workspace, source_node_id=source_node_id
-            )
-            res = await self.db.query(sql=SQL, multirows=True)
+            SQL = SQL_TEMPLATES["get_node_edges"]
+            params = {
+                "workspace": self.db.workspace,
+                "source_node_id": source_node_id,
+            }
+            res = await self.db.query(sql=SQL, params=params, multirows=True)
             if res:
                 data = [(i["source_name"], i["target_name"]) for i in res]
                 # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -525,7 +529,21 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Node Edge not exist!",self.db.workspace, source_node_id) # print("Node Edge not exist!",self.db.workspace, source_node_id)
return [] return []
async def get_all_nodes(self, limit: int):
"""查询所有节点"""
SQL = SQL_TEMPLATES["get_all_nodes"]
params = {"workspace":self.db.workspace, "limit":str(limit)}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
async def get_all_edges(self, limit: int):
"""查询所有边"""
SQL = SQL_TEMPLATES["get_all_edges"]
params = {"workspace":self.db.workspace, "limit":str(limit)}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
N_T = { N_T = {
"full_docs": "LIGHTRAG_DOC_FULL", "full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS", "text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -619,82 +637,96 @@ TABLES = {
 SQL_TEMPLATES = {
     # SQL for KVStorage
-    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
-    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
-    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
-    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
-    "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
+    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
+    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
+    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
+    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+    "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
     "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
         USING DUAL
-        ON (a.id = '{check_id}')
+        ON (a.id = :check_id)
         WHEN NOT MATCHED THEN
-        INSERT(id,content,workspace) values(:1,:2,:3)
+        INSERT(id,content,workspace) values(:id,:content,:workspace)
     """,
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
         USING DUAL
-        ON (a.id = '{check_id}')
+        ON (a.id = :check_id)
         WHEN NOT MATCHED THEN
         INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
-        values (:1,:2,:3,:4,:5,:6,:7) """,
+        values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
     # SQL for VectorStorage
     "entities": """SELECT name as entity_name FROM
-        (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-        FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
-        WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+        (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+        FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
+        WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
     "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
-        (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-        FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
-        WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+        (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+        FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
+        WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
     "chunks": """SELECT id FROM
-        (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-        FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
-        WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+        (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+        FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
+        WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
     # SQL for GraphStorage
     "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)
-        WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+        WHERE a.workspace=:workspace AND a.name=:node_id
         COLUMNS (a.name))""",
     "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
         MATCH (a) -[e]-> (b)
-        WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-        AND a.name='{source_node_id}' AND b.name='{target_node_id}'
+        WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+        AND a.name=:source_node_id AND b.name=:target_node_id
         COLUMNS (e.source_name,e.target_name) )""",
     "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)-[e]->(b)
-        WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-        AND a.name='{node_id}' or b.name = '{node_id}'
+        WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+        AND a.name=:node_id or b.name = :node_id
         COLUMNS (a.name))""",
     "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
         FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)
-        WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+        WHERE a.workspace=:workspace AND a.name=:node_id
         COLUMNS (a.name)
         ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
-        WHERE t2.workspace='{workspace}'""",
+        WHERE t2.workspace=:workspace""",
     "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
         NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
         FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)-[e]->(b)
-        WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-        AND a.name='{source_node_id}' and b.name = '{target_node_id}'
+        WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+        AND a.name=:source_node_id and b.name = :target_node_id
         COLUMNS (e.id,a.name as source_id)
         ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
     "get_node_edges": """SELECT source_name,target_name
         FROM GRAPH_TABLE (lightrag_graph
         MATCH (a)-[e]->(b)
-        WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-        AND a.name='{source_node_id}'
+        WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+        AND a.name=:source_node_id
         COLUMNS (a.name as source_name,b.name as target_name))""",
     "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
         USING DUAL
-        ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
+        ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
         WHEN NOT MATCHED THEN
         INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
-        values (:1,:2,:3,:4,:5,:6,:7) """,
+        values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
     "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
         USING DUAL
-        ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
+        ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
         WHEN NOT MATCHED THEN
         INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-        values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
+        values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+    "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
+        FROM LIGHTRAG_GRAPH_NODES t1
+        LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+        WHERE t1.workspace=:workspace
+        order by t1.CREATETIME DESC
+        fetch first :limit rows only
+    """,
+    "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
+        t1.weight,t1.DESCRIPTION,t2.content
+        FROM LIGHTRAG_GRAPH_EDGES t1
+        LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+        WHERE t1.workspace=:workspace
+        order by t1.CREATETIME DESC
+        fetch first :limit rows only""",
 }
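
A pattern worth noting in these templates: scalar values become named binds (a name repeated in the SQL, like :workspace three times, is still supplied once in the dict), while identifiers and IN-lists such as {table_name} and {ids} stay as str.format() substitutions, since bind variables cannot replace object names or expand to a variable number of values. Sketch:

keys = ["k1", "k2"]
sql = SQL_TEMPLATES["filter_keys"].format(
    table_name="LIGHTRAG_DOC_FULL",
    ids=",".join(f"'{k}'" for k in keys),  # still string-built; keys are internal hashes here
)
# res = await db.query(sql, {"workspace": "default"}, multirows=True)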

View File

@@ -405,12 +405,13 @@ async def local_query(
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
+    logger.debug("local_query json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ", ".join(keywords)
     except json.JSONDecodeError:
+        print(result)
         try:
             result = (
                 result.replace(kw_prompt[:-1], "")
@@ -443,6 +444,8 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -672,7 +675,7 @@ async def global_query(
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
+    logger.debug("global json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         keywords = keywords_data.get("high_level_keywords", [])
@@ -714,6 +717,8 @@ async def global_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -914,6 +919,7 @@ async def hybrid_query(
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
+    logger.debug("hybrid_query json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -969,6 +975,8 @@ async def hybrid_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -1079,6 +1087,8 @@ async def naive_query(
     sys_prompt = sys_prompt_temp.format(
         content_data=section, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,

View File

@@ -49,7 +49,11 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string""" """Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL) maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None: if maybe_json_str is not None:
return maybe_json_str.group(0) maybe_json_str = maybe_json_str.group(0)
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
return maybe_json_str
else: else:
return None return None
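
The helper now sanitizes the candidate JSON before returning it: literal and escaped newlines are stripped and single quotes become double quotes, which rescues the Python-dict-style output some models emit. A quick sketch of the effect:

import json

reply = "Sure! {'high_level_keywords': ['graph', 'RAG'],\n 'low_level_keywords': []}"
json_text = locate_json_string_body_from_string(reply)
# json_text == '{"high_level_keywords": ["graph", "RAG"], "low_level_keywords": []}'
keywords = json.loads(json_text)

The blanket quote replacement is a trade-off: a value containing an apostrophe (e.g. "it's") would be corrupted, so this favors the common model failure mode over edge cases.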