Fix the lint issue

Samuel Chan
2025-01-04 18:49:32 +08:00
parent 11f889a9df
commit 6c1b669f0f
4 changed files with 155 additions and 98 deletions

View File

@@ -53,7 +53,7 @@ async def main():
kv_storage="PGKVStorage", kv_storage="PGKVStorage",
doc_status_storage="PGDocStatusStorage", doc_status_storage="PGDocStatusStorage",
graph_storage="PGGraphStorage", graph_storage="PGGraphStorage",
vector_storage="PGVectorStorage" vector_storage="PGVectorStorage",
) )
# Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
rag.doc_status.db = postgres_db rag.doc_status.db = postgres_db
@@ -77,27 +77,35 @@ async def main():
     start_time = time.time()
     # Perform naive search
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="naive"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
     )
     print(f"Naive Query Time: {time.time() - start_time} seconds")

     # Perform local search
     print("**** Start Local Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="local"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
     )
     print(f"Local Query Time: {time.time() - start_time} seconds")

     # Perform global search
     print("**** Start Global Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="global"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
     )
     print(f"Global Query Time: {time.time() - start_time}")

     # Perform hybrid search
     print("**** Start Hybrid Query ****")
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
     )
     print(f"Hybrid Query Time: {time.time() - start_time} seconds")

View File

@@ -19,7 +19,11 @@ from tenacity import (
 from ..utils import logger
 from ..base import (
     BaseKVStorage,
-    BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
+    BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+    BaseGraphStorage,
 )

 if sys.platform.startswith("win"):
@@ -36,14 +40,15 @@ class PostgreSQLDB:
         self.user = config.get("user", "postgres")
         self.password = config.get("password", None)
         self.database = config.get("database", "postgres")
-        self.workspace = config.get("workspace", 'default')
+        self.workspace = config.get("workspace", "default")
         self.max = 12
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
         if self.user is None or self.password is None or self.database is None:
-            raise ValueError("Missing database user, password, or database in addon_params")
+            raise ValueError(
+                "Missing database user, password, or database in addon_params"
+            )

     async def initdb(self):
         try:
@@ -54,12 +59,16 @@ class PostgreSQLDB:
                 host=self.host,
                 port=self.port,
                 min_size=1,
-                max_size=self.max
+                max_size=self.max,
             )
-            logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.info(
+                f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
         except Exception as e:
-            logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.error(
+                f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
             logger.error(f"PostgreSQL database error: {e}")
             raise
@@ -79,9 +88,13 @@ class PostgreSQLDB:
logger.info("Finished checking all tables in PostgreSQL database") logger.info("Finished checking all tables in PostgreSQL database")
async def query( async def query(
self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None self,
sql: str,
params: dict = None,
multirows: bool = False,
for_age: bool = False,
graph_name: str = None,
) -> Union[dict, None, list[dict]]: ) -> Union[dict, None, list[dict]]:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
try: try:
@@ -111,7 +124,13 @@ class PostgreSQLDB:
                 print(params)
                 raise

-    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
+    async def execute(
+        self,
+        sql: str,
+        data: Union[list, dict] = None,
+        for_age: bool = False,
+        graph_name: str = None,
+    ):
         try:
             async with self.pool.acquire() as connection:
                 if for_age:
@@ -130,7 +149,7 @@ class PostgreSQLDB:
     @staticmethod
     async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
         try:
-            await conn.execute(f'SET search_path = ag_catalog, "$user", public')
+            await conn.execute('SET search_path = ag_catalog, "$user", public')
             await conn.execute(f"""select create_graph('{graph_name}')""")
         except asyncpg.exceptions.InvalidSchemaNameError:
             pass
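This one-character change is a textbook F541 fix: an f-string without placeholders gains nothing from the prefix, so the linter strips it. A minimal illustration of the rule (hypothetical snippet, not from this repo):

    stmt = f"SELECT 1"   # F541: no {placeholders}, the f is dead weight
    stmt = "SELECT 1"    # fixed: plain string literal

    graph_name = "dickens"
    stmt = f"select create_graph('{graph_name}')"  # real interpolation, prefix stays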
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=NAMESPACE_TABLE_MAP[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+            table_name=NAMESPACE_TABLE_MAP[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         params = {"workspace": self.db.workspace}
         try:
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
                 print(sql)
                 print(params)

-
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
"content_vector": json.dumps(item["__vector__"].tolist()), "content_vector": json.dumps(item["__vector__"].tolist()),
} }
return upsert_sql, data return upsert_sql, data
def _upsert_relationships(self, item: dict): def _upsert_relationships(self, item: dict):
upsert_sql = SQL_TEMPLATES["upsert_relationship"] upsert_sql = SQL_TEMPLATES["upsert_relationship"]
data = { data = {
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):
             await self.db.execute(upsert_sql, data)
-
-
     async def index_done_callback(self):
         logger.info("vector data had been saved into postgresql db!")
@@ -361,9 +379,11 @@ class PGVectorStorage(BaseVectorStorage):
         results = await self.db.query(sql, params=params, multirows=True)
         return results
+
+
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
     """PostgreSQL implementation of document status storage"""

     db: PostgreSQLDB = None

     def __post_init__(self):
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
             return set(data)
         else:
-            existed = set([element['id'] for element in result])
+            existed = set([element["id"] for element in result])
             return set(data) - existed

     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
-        sql = '''SELECT status as "status", COUNT(1) as "count"
+        sql = """SELECT status as "status", COUNT(1) as "count"
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
-                 '''
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+                 """
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
         return counts

-    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> Dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        sql = 'select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1'
-        params = {'workspace': self.db.workspace, 'status': status}
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
+        params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
         # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
         # Converting to be a dict
-        return {element["id"]:
-                DocProcessingStatus(content_summary=element["content_summary"],
-                                    content_length=element["content_length"],
-                                    status=element["status"],
-                                    created_at=element["created_at"],
-                                    updated_at=element["updated_at"],
-                                    chunks_count=element["chunks_count"]) for element in result}
+        return {
+            element["id"]: DocProcessingStatus(
+                content_summary=element["content_summary"],
+                content_length=element["content_length"],
+                status=element["status"],
+                created_at=element["created_at"],
+                updated_at=element["updated_at"],
+                chunks_count=element["chunks_count"],
+            )
+            for element in result
+        }

     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
         """Get all failed documents"""
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
                 updated_at = CURRENT_TIMESTAMP"""
         for k, v in data.items():
             # chunks_count is optional
-            await self.db.execute(sql, {
+            await self.db.execute(
+                sql,
+                {
                     "workspace": self.db.workspace,
                     "id": k,
                     "content_summary": v["content_summary"],
                     "content_length": v["content_length"],
                     "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                     "status": v["status"],
-            })
+                },
+            )
         return data
@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self): async def index_done_callback(self):
print("KG successfully indexed.") print("KG successfully indexed.")
@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):
         # get pgsql formatted field names
         fields = [
-            PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
+            PGGraphStorage._get_col_name(field, idx)
+            for idx, field in enumerate(fields)
         ]

         # build resulting pgsql relation
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
             projection=select_str,
         )

-    async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
+    async def _query(
+        self, query: str, readonly=True, upsert_edge=False, **params: str
+    ) -> List[Dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
         age compatible query, executing it and converting the result
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
         # execute the query, rolling back on an error
         try:
             if readonly:
-                data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
+                data = await self.db.query(
+                    wrapped_query,
+                    multirows=True,
+                    for_age=True,
+                    graph_name=self.graph_name,
+                )
             else:
                 # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
                 # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
                 if upsert_edge:
-                    data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        f"{wrapped_query};{wrapped_query};",
+                        for_age=True,
+                        graph_name=self.graph_name,
+                    )
                 else:
-                    data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        wrapped_query, for_age=True, graph_name=self.graph_name
+                    )
         except Exception as e:
             raise PGGraphQueryException(
                 {
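The reflow here also preserves the most interesting comment in the file: because of an AGE bug (as of 2025-01-03), an edge-upserting cypher statement must run twice; the first pass only creates the edge and the second pass MERGEs the property map onto it. A standalone sketch of the same workaround over plain asyncpg (upsert_edge_twice and its parameters are hypothetical, not this repo's API):

    import asyncpg

    async def upsert_edge_twice(conn: asyncpg.Connection, graph: str, cypher: str) -> None:
        # AGE's cypher() function lives in ag_catalog, so it must be on the search_path.
        await conn.execute('SET search_path = ag_catalog, "$user", public')
        wrapped = f"SELECT * FROM cypher('{graph}', $$ {cypher} $$) AS (result agtype)"
        # Same statement twice in one script, mirroring the diff's
        # f"{wrapped_query};{wrapped_query};" workaround.
        await conn.execute(f"{wrapped};{wrapped};")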
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
+        logger.info(
+            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
+        )

         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1056,7 +1100,6 @@ TABLES = {
 }
-
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
     (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
      FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
     WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """
+    """,
 }
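One note on this final template: `<=>` is pgvector's cosine-distance operator, so 1 - (content_vector <=> ...) is cosine similarity even though the column is named distance; that is why the filter keeps rows above the threshold and sorts descending. Issued directly through asyncpg, the same idea might look like the sketch below (the table and column names come from the diff; the helper variables are illustrative, and the embedding travels as pgvector's bracketed text form):

    import json

    rows = await conn.fetch(
        """
        SELECT id, 1 - (content_vector <=> $2::vector) AS similarity
        FROM LIGHTRAG_DOC_CHUNKS
        WHERE workspace = $1
        ORDER BY similarity DESC
        LIMIT $3
        """,
        workspace,              # e.g. "default"
        json.dumps(embedding),  # '[0.1, 0.2, ...]' -- text form cast by ::vector
        top_k,
    )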

View File

@@ -1,6 +1,7 @@
 import asyncio
 import asyncpg
-import sys, os
+import sys
+import os
 import psycopg
 from psycopg_pool import AsyncConnectionPool
@@ -15,19 +16,24 @@ os.environ["AGE_GRAPH_NAME"] = "dickens"
 if sys.platform.startswith("win"):
     import asyncio.windows_events

     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+
 async def get_pool():
     return await asyncpg.create_pool(
         f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
         min_size=10,
         max_size=10,
         max_queries=5000,
-        max_inactive_connection_lifetime=300.0
+        max_inactive_connection_lifetime=300.0,
     )

+
 async def main1():
-    connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    connection_string = (
+        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    )
     pool = AsyncConnectionPool(connection_string, open=False)
     await pool.open()
@@ -36,7 +42,7 @@ async def main1():
         async with conn.cursor() as curs:
             try:
                 await curs.execute('SET search_path = ag_catalog, "$user", public')
-                await curs.execute(f"SELECT create_graph('dickens-2')")
+                await curs.execute("SELECT create_graph('dickens-2')")
                 await conn.commit()
                 print("create_graph success")
             except (
@@ -48,6 +54,7 @@ async def main1():
     finally:
         pass

+
 db = PostgreSQLDB(
     config={
         "host": "localhost",
@@ -58,6 +65,7 @@ db = PostgreSQLDB(
     }
 )

+
 async def query_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -69,6 +77,7 @@ async def query_with_age():
     res = await graph.get_node('"CHRISTMAS-TIME"')
     print("Node is: ", res)

+
 async def create_edge_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -89,31 +98,28 @@ async def create_edge_with_age():
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
}, },
) )
res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"') res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
print("Edge is: ", res) print("Edge is: ", res)
async def main(): async def main():
pool = await get_pool() pool = await get_pool()
# 如果还有其它什么特殊参数,也可以直接往里面传递,因为设置了 **connect_kwargs
# 专门用来设置一些数据库独有的某些属性
# 从池子中取出一个连接
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
# cypher = "MATCH (n:how_are_you_doing) RETURN n" # cypher = "MATCH (n:how_are_you_doing) RETURN n"
async with pool.acquire() as conn: async with pool.acquire() as conn:
try: try:
await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""") await conn.execute(
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
)
except asyncpg.exceptions.InvalidSchemaNameError: except asyncpg.exceptions.InvalidSchemaNameError:
print("create_graph already exists") print("create_graph already exists")
# stmt = await conn.prepare(sql) # stmt = await conn.prepare(sql)
row = await conn.fetch(sql) row = await conn.fetch(sql)
print("row is: ", row) print("row is: ", row)
# 解决办法就是起一个别名
row = await conn.fetchrow("select '100'::int + 200 as result") row = await conn.fetchrow("select '100'::int + 200 as result")
print(row) # <Record result=300> print(row) # <Record result=300>
# 我们的连接是从池子里面取出的,上下文结束之后会自动放回到到池子里面
if __name__ == '__main__': if __name__ == "__main__":
asyncio.run(query_with_age()) asyncio.run(query_with_age())

View File

@@ -1,6 +1,8 @@
 accelerate
 aioboto3~=13.3.0
+aiofiles~=24.1.0
 aiohttp~=3.11.11
+asyncpg~=0.30.0

 # database packages
 graspologic
@@ -9,14 +11,20 @@ hnswlib
 nano-vectordb
 neo4j~=5.27.0
 networkx~=3.2.1
+numpy~=2.2.0
 ollama~=0.4.4
 openai~=1.58.1
 oracledb
+psycopg-pool~=3.2.4
 psycopg[binary,pool]~=3.2.3
+pydantic~=2.10.4
 pymilvus
 pymongo
 pymysql
+python-dotenv~=1.0.1
 pyvis~=0.3.2
+setuptools~=70.0.0
 # lmdeploy[all]
 sqlalchemy~=2.0.36
 tenacity~=9.0.0
@@ -25,14 +33,6 @@ tenacity~=9.0.0
 # LLM packages
 tiktoken~=0.8.0
 torch~=2.5.1+cu121
+tqdm~=4.67.1
 transformers~=4.47.1
 xxhash
-
-numpy~=2.2.0
-aiofiles~=24.1.0
-pydantic~=2.10.4
-python-dotenv~=1.0.1
-psycopg-pool~=3.2.4
-tqdm~=4.67.1
-asyncpg~=0.30.0
-setuptools~=70.0.0
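The requirements churn is a single move: packages that had been appended at the bottom are slotted back into alphabetical order. A quick guard to keep that property, sketched in Python (only the requirements.txt filename comes from this repo; the rest is illustrative):

    from pathlib import Path

    lines = Path("requirements.txt").read_text().splitlines()
    # Compare only real requirement lines; comments and blanks just delimit sections.
    packages = [ln.strip() for ln in lines if ln.strip() and not ln.startswith("#")]
    assert packages == sorted(packages, key=str.lower), "requirements.txt is not sorted"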