Fix the lint issue
@@ -53,7 +53,7 @@ async def main():
         kv_storage="PGKVStorage",
         doc_status_storage="PGDocStatusStorage",
         graph_storage="PGGraphStorage",
-        vector_storage="PGVectorStorage"
+        vector_storage="PGVectorStorage",
     )
     # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
     rag.doc_status.db = postgres_db
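For orientation, the wiring these example hunks touch looks roughly like the sketch below, reconstructed from the hunks in this commit. It is not the full example file: the `postgres_db` name, the connection values, the import paths, and `working_dir` are assumptions.

```python
from lightrag import LightRAG  # import path assumed
from lightrag.kg.postgres_impl import PostgreSQLDB  # module path assumed

# Placeholder connection details; PostgreSQLDB takes a config dict of this
# shape, as the test script further down shows.
postgres_db = PostgreSQLDB(
    config={
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "<password>",
        "database": "postgres",
    }
)
# await postgres_db.initdb()  # called inside the example's async main()

rag = LightRAG(
    working_dir="./dickens",  # assumed working directory
    kv_storage="PGKVStorage",
    doc_status_storage="PGDocStatusStorage",
    graph_storage="PGGraphStorage",
    vector_storage="PGVectorStorage",
)
# Share one connection pool across the storages; the example does this for
# doc status (and, per its comment, for the KV/vector/graph storages too).
rag.doc_status.db = postgres_db
```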
@@ -77,27 +77,35 @@ async def main():
     start_time = time.time()
     # Perform naive search
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="naive"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
     )
     print(f"Naive Query Time: {time.time() - start_time} seconds")
     # Perform local search
     print("**** Start Local Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="local"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
     )
     print(f"Local Query Time: {time.time() - start_time} seconds")
     # Perform global search
     print("**** Start Global Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="global"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
     )
     print(f"Global Query Time: {time.time() - start_time}")
     # Perform hybrid search
     print("**** Start Hybrid Query ****")
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
     )
     print(f"Hybrid Query Time: {time.time() - start_time} seconds")

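The four reformatted blocks repeat one pattern, so the same benchmark collapses to a loop. A sketch assuming an already-configured `rag` instance; the `QueryParam` import path is assumed from standard LightRAG usage.

```python
import time

from lightrag import QueryParam  # import path assumed


async def benchmark_queries(rag) -> None:
    # Mirrors the four blocks above: run each query mode and time it.
    for mode in ("naive", "local", "global", "hybrid"):
        print(f"**** Start {mode.title()} Query ****")
        start_time = time.time()
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode=mode),
            )
        )
        print(f"{mode.title()} Query Time: {time.time() - start_time} seconds")
```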
@@ -19,7 +19,11 @@ from tenacity import (
 from ..utils import logger
 from ..base import (
     BaseKVStorage,
-    BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
+    BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+    BaseGraphStorage,
 )

 if sys.platform.startswith("win"):
@@ -36,14 +40,15 @@ class PostgreSQLDB:
         self.user = config.get("user", "postgres")
         self.password = config.get("password", None)
         self.database = config.get("database", "postgres")
-        self.workspace = config.get("workspace", 'default')
+        self.workspace = config.get("workspace", "default")
         self.max = 12
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")

-
         if self.user is None or self.password is None or self.database is None:
-            raise ValueError("Missing database user, password, or database in addon_params")
+            raise ValueError(
+                "Missing database user, password, or database in addon_params"
+            )

     async def initdb(self):
         try:
@@ -54,12 +59,16 @@ class PostgreSQLDB:
                 host=self.host,
                 port=self.port,
                 min_size=1,
-                max_size=self.max
+                max_size=self.max,
             )

-            logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.info(
+                f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
         except Exception as e:
-            logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.error(
+                f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
             logger.error(f"PostgreSQL database error: {e}")
             raise

@@ -79,9 +88,13 @@ class PostgreSQLDB:

         logger.info("Finished checking all tables in PostgreSQL database")

-
     async def query(
-        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
+        self,
+        sql: str,
+        params: dict = None,
+        multirows: bool = False,
+        for_age: bool = False,
+        graph_name: str = None,
     ) -> Union[dict, None, list[dict]]:
         async with self.pool.acquire() as connection:
             try:
@@ -111,7 +124,13 @@ class PostgreSQLDB:
                 print(params)
                 raise

-    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
+    async def execute(
+        self,
+        sql: str,
+        data: Union[list, dict] = None,
+        for_age: bool = False,
+        graph_name: str = None,
+    ):
         try:
             async with self.pool.acquire() as connection:
                 if for_age:
@@ -130,7 +149,7 @@ class PostgreSQLDB:
     @staticmethod
     async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
         try:
-            await conn.execute(f'SET search_path = ag_catalog, "$user", public')
+            await conn.execute('SET search_path = ag_catalog, "$user", public')
             await conn.execute(f"""select create_graph('{graph_name}')""")
         except asyncpg.exceptions.InvalidSchemaNameError:
             pass
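The removed `f` prefix is the actual lint fix here: the string has no placeholders, so the f-string flags as extraneous. For reference, what `_prerequisite` does is the standard Apache AGE bootstrap: put `ag_catalog` on the `search_path` so AGE's functions resolve, then call `create_graph`, treating the error raised for an existing graph as "already created". A standalone sketch with asyncpg; the helper name, DSN, and graph name are placeholders.

```python
import asyncio

import asyncpg


async def ensure_graph(dsn: str, graph_name: str) -> None:
    """Hypothetical helper mirroring PostgreSQLDB._prerequisite."""
    conn = await asyncpg.connect(dsn)
    try:
        # AGE functions live in ag_catalog, so it must be on the search_path.
        await conn.execute('SET search_path = ag_catalog, "$user", public')
        try:
            await conn.execute(f"select create_graph('{graph_name}')")
        except asyncpg.exceptions.InvalidSchemaNameError:
            pass  # the graph already exists; AGE reports this as a schema error
    finally:
        await conn.close()


# asyncio.run(ensure_graph("postgres://postgres:secret@localhost:5432/postgres", "dickens"))
```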
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=NAMESPACE_TABLE_MAP[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+            table_name=NAMESPACE_TABLE_MAP[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         params = {"workspace": self.db.workspace}
         try:
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
             print(sql)
             print(params)

-
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
             "content_vector": json.dumps(item["__vector__"].tolist()),
         }
         return upsert_sql, data
+
     def _upsert_relationships(self, item: dict):
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
         data = {
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):

         await self.db.execute(upsert_sql, data)

-
-
     async def index_done_callback(self):
         logger.info("vector data had been saved into postgresql db!")

@@ -361,9 +379,11 @@ class PGVectorStorage(BaseVectorStorage):
         results = await self.db.query(sql, params=params, multirows=True)
         return results

+
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
     """PostgreSQL implementation of document status storage"""
+
     db: PostgreSQLDB = None

     def __post_init__(self):
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
             return set(data)
         else:
-            existed = set([element['id'] for element in result])
+            existed = set([element["id"] for element in result])
             return set(data) - existed

     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
-        sql = '''SELECT status as "status", COUNT(1) as "count"
+        sql = """SELECT status as "status", COUNT(1) as "count"
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
-                 '''
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+                 """
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
         return counts

-    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> Dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        sql = 'select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1'
-        params = {'workspace': self.db.workspace, 'status': status}
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
+        params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
         # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
         # Converting to be a dict
-        return {element["id"]:
-                DocProcessingStatus(content_summary=element["content_summary"],
-                                    content_length=element["content_length"],
-                                    status=element["status"],
-                                    created_at=element["created_at"],
-                                    updated_at=element["updated_at"],
-                                    chunks_count=element["chunks_count"]) for element in result}
+        return {
+            element["id"]: DocProcessingStatus(
+                content_summary=element["content_summary"],
+                content_length=element["content_length"],
+                status=element["status"],
+                created_at=element["created_at"],
+                updated_at=element["updated_at"],
+                chunks_count=element["chunks_count"],
+            )
+            for element in result
+        }

     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
         """Get all failed documents"""
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
                       updated_at = CURRENT_TIMESTAMP"""
         for k, v in data.items():
             # chunks_count is optional
-            await self.db.execute(sql, {
-                "workspace": self.db.workspace,
-                "id": k,
-                "content_summary": v["content_summary"],
-                "content_length": v["content_length"],
-                "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
-                "status": v["status"],
-            })
+            await self.db.execute(
+                sql,
+                {
+                    "workspace": self.db.workspace,
+                    "id": k,
+                    "content_summary": v["content_summary"],
+                    "content_length": v["content_length"],
+                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
+                    "status": v["status"],
+                },
+            )
         return data


@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

-
     async def index_done_callback(self):
         print("KG successfully indexed.")

@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):

         # get pgsql formatted field names
         fields = [
-            PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
+            PGGraphStorage._get_col_name(field, idx)
+            for idx, field in enumerate(fields)
         ]

         # build resulting pgsql relation
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
             projection=select_str,
         )

-    async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
+    async def _query(
+        self, query: str, readonly=True, upsert_edge=False, **params: str
+    ) -> List[Dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
         age compatible query, executing it and converting the result
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
         # execute the query, rolling back on an error
         try:
             if readonly:
-                data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
+                data = await self.db.query(
+                    wrapped_query,
+                    multirows=True,
+                    for_age=True,
+                    graph_name=self.graph_name,
+                )
             else:
                 # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
                 # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
                 if upsert_edge:
-                    data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        f"{wrapped_query};{wrapped_query};",
+                        for_age=True,
+                        graph_name=self.graph_name,
+                    )
                 else:
-                    data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        wrapped_query, for_age=True, graph_name=self.graph_name
+                    )
         except Exception as e:
             raise PGGraphQueryException(
                 {
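Besides the formatting, this hunk preserves a workaround the inline comments describe: with AGE (as of 2025-01-03), an edge-upserting cypher statement has to run twice, because the first execution only creates the edge and the second pass is the MERGE that writes the properties. Condensed into a standalone function for readability (a sketch; only the call shapes come from the hunk, the helper itself is hypothetical):

```python
async def run_wrapped_query(db, wrapped_query: str, graph_name: str,
                            readonly: bool = True, upsert_edge: bool = False):
    """Hypothetical condensation of PGGraphStorage._query's dispatch."""
    if readonly:
        return await db.query(
            wrapped_query, multirows=True, for_age=True, graph_name=graph_name
        )
    if upsert_edge:
        # AGE workaround: run the statement twice; the first pass creates the
        # edge, the second MERGE pass updates the properties it could not set.
        return await db.execute(
            f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=graph_name
        )
    return await db.execute(wrapped_query, for_age=True, graph_name=graph_name)
```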
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
+        logger.info(
+            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
+        )

         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1056,7 +1100,6 @@ TABLES = {
 }


-
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
         (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
          FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
         WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """
+    """,
 }
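One note on the template this last hunk closes out: pgvector's `<=>` operator returns cosine distance, so `1 - (content_vector <=> ...)` is a cosine similarity even though the alias is `distance`; the outer `WHERE distance>$2 ... LIMIT $3` therefore keeps the top-k most similar chunks above a threshold. A hedged sketch of running an equivalent query directly with asyncpg; the helper and the subquery alias are additions, the table and placeholders follow the hunk, and the pgvector extension is required.

```python
import asyncpg


async def top_chunks(
    conn: asyncpg.Connection,
    embedding: list[float],
    workspace: str,
    threshold: float,
    top_k: int,
):
    """Hypothetical helper for the chunk-similarity template above."""
    # The template inlines the embedding as a pgvector literal.
    embedding_string = ",".join(str(x) for x in embedding)
    sql = f"""
        SELECT id, distance FROM
            (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) AS distance
             FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1) sub
        WHERE distance > $2 ORDER BY distance DESC LIMIT $3
    """
    return await conn.fetch(sql, workspace, threshold, top_k)
```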
@@ -1,6 +1,7 @@
 import asyncio
 import asyncpg
-import sys, os
+import sys
+import os

 import psycopg
 from psycopg_pool import AsyncConnectionPool
@@ -15,19 +16,24 @@ os.environ["AGE_GRAPH_NAME"] = "dickens"

 if sys.platform.startswith("win"):
     import asyncio.windows_events

     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+
 async def get_pool():
     return await asyncpg.create_pool(
         f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
         min_size=10,
         max_size=10,
         max_queries=5000,
-        max_inactive_connection_lifetime=300.0
+        max_inactive_connection_lifetime=300.0,
     )

+
 async def main1():
-    connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    connection_string = (
+        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    )
     pool = AsyncConnectionPool(connection_string, open=False)
     await pool.open()

@@ -36,7 +42,7 @@ async def main1():
         async with conn.cursor() as curs:
             try:
                 await curs.execute('SET search_path = ag_catalog, "$user", public')
-                await curs.execute(f"SELECT create_graph('dickens-2')")
+                await curs.execute("SELECT create_graph('dickens-2')")
                 await conn.commit()
                 print("create_graph success")
             except (
@@ -48,6 +54,7 @@ async def main1():
         finally:
             pass

+
 db = PostgreSQLDB(
     config={
         "host": "localhost",
@@ -58,6 +65,7 @@ db = PostgreSQLDB(
     }
 )

+
 async def query_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -69,6 +77,7 @@ async def query_with_age():
     res = await graph.get_node('"CHRISTMAS-TIME"')
     print("Node is: ", res)

+
 async def create_edge_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -89,31 +98,28 @@ async def create_edge_with_age():
             "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
         },
     )
-    res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"')
+    res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
     print("Edge is: ", res)


 async def main():
     pool = await get_pool()
-    # Any other special parameters can also be passed straight through, because **connect_kwargs is set
-    # (it is meant for setting database-specific connection properties)
-    # Take a connection from the pool
     sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
     # cypher = "MATCH (n:how_are_you_doing) RETURN n"
     async with pool.acquire() as conn:
         try:
-            await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""")
+            await conn.execute(
+                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
+            )
         except asyncpg.exceptions.InvalidSchemaNameError:
             print("create_graph already exists")
         # stmt = await conn.prepare(sql)
         row = await conn.fetch(sql)
         print("row is: ", row)

-        # The workaround is to give the agtype column an alias
         row = await conn.fetchrow("select '100'::int + 200 as result")
         print(row)  # <Record result=300>
-        # The connection came from the pool; it is returned automatically when the context exits


-if __name__ == '__main__':
+if __name__ == "__main__":
     asyncio.run(query_with_age())
@@ -1,6 +1,8 @@
 accelerate
 aioboto3~=13.3.0
+aiofiles~=24.1.0
 aiohttp~=3.11.11
+asyncpg~=0.30.0

 # database packages
 graspologic
@@ -9,14 +11,20 @@ hnswlib
 nano-vectordb
 neo4j~=5.27.0
 networkx~=3.2.1
+
+numpy~=2.2.0
 ollama~=0.4.4
 openai~=1.58.1
 oracledb
+psycopg-pool~=3.2.4
 psycopg[binary,pool]~=3.2.3
+pydantic~=2.10.4
 pymilvus
 pymongo
 pymysql
+python-dotenv~=1.0.1
 pyvis~=0.3.2
+setuptools~=70.0.0
 # lmdeploy[all]
 sqlalchemy~=2.0.36
 tenacity~=9.0.0
@@ -25,14 +33,6 @@ tenacity~=9.0.0
 # LLM packages
 tiktoken~=0.8.0
 torch~=2.5.1+cu121
+tqdm~=4.67.1
 transformers~=4.47.1
 xxhash
-
-numpy~=2.2.0
-aiofiles~=24.1.0
-pydantic~=2.10.4
-python-dotenv~=1.0.1
-psycopg-pool~=3.2.4
-tqdm~=4.67.1
-asyncpg~=0.30.0
-setuptools~=70.0.0