cleanup code

This commit is contained in:
Yannick Stephan
2025-02-18 16:55:48 +01:00
parent 99dc4859a9
commit 46e1865b98
4 changed files with 33 additions and 55 deletions

View File

@@ -140,8 +140,6 @@ class OracleDB:
await cursor.execute(sql, params) await cursor.execute(sql, params)
except Exception as e: except Exception as e:
logger.error(f"Oracle database error: {e}") logger.error(f"Oracle database error: {e}")
print(sql)
print(params)
raise raise
columns = [column[0].lower() for column in cursor.description] columns = [column[0].lower() for column in cursor.description]
if multirows: if multirows:
@@ -172,8 +170,6 @@ class OracleDB:
await connection.commit() await connection.commit()
except Exception as e: except Exception as e:
logger.error(f"Oracle database error: {e}") logger.error(f"Oracle database error: {e}")
print(sql)
print(data)
raise raise
@@ -349,9 +345,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
"top_k": top_k, "top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold, "better_than_threshold": self.cosine_better_than_threshold,
} }
# print(SQL)
results = await self.db.query(SQL, params=params, multirows=True) results = await self.db.query(SQL, params=params, multirows=True)
# print("vector search result:",results)
return results return results
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -477,8 +471,6 @@ class OracleGraphStorage(BaseGraphStorage):
"""根据节点id检查节点是否存在""" """根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"] SQL = SQL_TEMPLATES["has_node"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
# print(self.db.workspace, node_id)
res = await self.db.query(SQL, params) res = await self.db.query(SQL, params)
if res: if res:
# print("Node exist!",res) # print("Node exist!",res)
@@ -494,7 +486,6 @@ class OracleGraphStorage(BaseGraphStorage):
"source_node_id": source_node_id, "source_node_id": source_node_id,
"target_node_id": target_node_id, "target_node_id": target_node_id,
} }
# print(SQL)
res = await self.db.query(SQL, params) res = await self.db.query(SQL, params)
if res: if res:
# print("Edge exist!",res) # print("Edge exist!",res)
@@ -506,33 +497,25 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
SQL = SQL_TEMPLATES["node_degree"] SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
res = await self.db.query(SQL, params) res = await self.db.query(SQL, params)
if res: if res:
# print("Node degree",res["degree"])
return res["degree"] return res["degree"]
else: else:
# print("Edge not exist!")
return 0 return 0
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""根据源和目标节点id获取边的度""" """根据源和目标节点id获取边的度"""
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
# print("Edge degree",degree)
return degree return degree
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""根据节点id获取节点数据""" """根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"] SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
# print(self.db.workspace, node_id)
# print(SQL)
res = await self.db.query(SQL, params) res = await self.db.query(SQL, params)
if res: if res:
# print("Get node!",self.db.workspace, node_id,res)
return res return res
else: else:
# print("Can't get node!",self.db.workspace, node_id)
return None return None
async def get_edge( async def get_edge(

View File

@@ -136,9 +136,9 @@ class PostgreSQLDB:
data = None data = None
return data return data
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database error: {e}") logger.error(
print(sql) f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
print(params) )
raise raise
async def execute( async def execute(
@@ -167,9 +167,7 @@ class PostgreSQLDB:
else: else:
logger.error(f"Upsert error: {e}") logger.error(f"Upsert error: {e}")
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database error: {e.__class__} - {e}") logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
print(sql)
print(data)
raise raise
@staticmethod @staticmethod
@@ -266,9 +264,10 @@ class PGKVStorage(BaseKVStorage):
new_keys = set([s for s in keys if s not in exist_keys]) new_keys = set([s for s in keys if s not in exist_keys])
return new_keys return new_keys
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database error: {e}") logger.error(
print(sql) f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
print(params) )
raise
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -333,9 +332,9 @@ class PGVectorStorage(BaseVectorStorage):
"content_vector": json.dumps(item["__vector__"].tolist()), "content_vector": json.dumps(item["__vector__"].tolist()),
} }
except Exception as e: except Exception as e:
logger.error(f"Error to prepare upsert sql: {e}") logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
print(item) raise
raise e
return upsert_sql, data return upsert_sql, data
def _upsert_entities(self, item: dict): def _upsert_entities(self, item: dict):
@@ -454,9 +453,10 @@ class PGDocStatusStorage(DocStatusStorage):
print(f"new_keys: {new_keys}") print(f"new_keys: {new_keys}")
return new_keys return new_keys
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database error: {e}") logger.error(
print(sql) f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
print(params) )
raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"

View File

@@ -76,9 +76,7 @@ class TiDB:
try: try:
result = conn.execute(text(sql), params) result = conn.execute(text(sql), params)
except Exception as e: except Exception as e:
logger.error(f"Tidb database error: {e}") logger.error(f"Tidb database,\nsql:{sql},\nparams:{params},\nerror:{e}")
print(sql)
print(params)
raise raise
if multirows: if multirows:
rows = result.all() rows = result.all()
@@ -103,9 +101,7 @@ class TiDB:
else: else:
conn.execute(text(sql), parameters=data) conn.execute(text(sql), parameters=data)
except Exception as e: except Exception as e:
logger.error(f"TiDB database error: {e}") logger.error(f"Tidb database,\nsql:{sql},\ndata:{data},\nerror:{e}")
print(sql)
print(data)
raise raise
@@ -145,8 +141,7 @@ class TiDBKVStorage(BaseKVStorage):
try: try:
await self.db.query(SQL) await self.db.query(SQL)
except Exception as e: except Exception as e:
logger.error(f"Tidb database error: {e}") logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}")
print(SQL)
res = await self.db.query(SQL, multirows=True) res = await self.db.query(SQL, multirows=True)
if res: if res:
exist_keys = [key["id"] for key in res] exist_keys = [key["id"] for key in res]

View File

@@ -77,7 +77,7 @@ from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__ from lightrag.api import __api_version__
import numpy as np import numpy as np
from typing import Union from typing import Any, Union
class InvalidResponseError(Exception): class InvalidResponseError(Exception):
@@ -94,13 +94,13 @@ class InvalidResponseError(Exception):
), ),
) )
async def openai_complete_if_cache( async def openai_complete_if_cache(
model, model: str,
prompt, prompt: str,
system_prompt=None, system_prompt: str | None = None,
history_messages=None, history_messages: list[dict[str, Any]] | None = None,
base_url=None, base_url: str | None = None,
api_key=None, api_key: str | None = None,
**kwargs, **kwargs: Any,
) -> str: ) -> str:
if history_messages is None: if history_messages is None:
history_messages = [] history_messages = []
@@ -125,7 +125,7 @@ async def openai_complete_if_cache(
) )
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)
messages = [] messages: list[dict[str, Any]] = []
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages) messages.extend(history_messages)
@@ -147,18 +147,18 @@ async def openai_complete_if_cache(
model=model, messages=messages, **kwargs model=model, messages=messages, **kwargs
) )
except APIConnectionError as e: except APIConnectionError as e:
logger.error(f"OpenAI API Connection Error: {str(e)}") logger.error(f"OpenAI API Connection Error: {e}")
raise raise
except RateLimitError as e: except RateLimitError as e:
logger.error(f"OpenAI API Rate Limit Error: {str(e)}") logger.error(f"OpenAI API Rate Limit Error: {e}")
raise raise
except APITimeoutError as e: except APITimeoutError as e:
logger.error(f"OpenAI API Timeout Error: {str(e)}") logger.error(f"OpenAI API Timeout Error: {e}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"OpenAI API Call Failed: {str(e)}") logger.error(
logger.error(f"Model: {model}") f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
logger.error(f"Request parameters: {kwargs}") )
raise raise
if hasattr(response, "__aiter__"): if hasattr(response, "__aiter__"):