Fix get_all_labels for PostgreSQL
This commit is contained in:
@@ -653,7 +653,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label: Label of the starting node,* means all nodes
|
node_label: Label of the starting node, * means all nodes
|
||||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
||||||
|
|
||||||
|
@@ -9,7 +9,6 @@ import configparser
|
|||||||
|
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
import sys
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
@@ -28,11 +27,6 @@ from ..base import (
|
|||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
|
||||||
import asyncio.windows_events
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("asyncpg"):
|
if not pm.is_installed("asyncpg"):
|
||||||
@@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"):
|
|||||||
import asyncpg # type: ignore
|
import asyncpg # type: ignore
|
||||||
from asyncpg import Pool # type: ignore
|
from asyncpg import Pool # type: ignore
|
||||||
|
|
||||||
|
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||||
|
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLDB:
|
class PostgreSQLDB:
|
||||||
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
||||||
@@ -1535,14 +1532,13 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
MATCH (n:base)
|
MATCH (n:base)
|
||||||
WHERE n.entity_id IS NOT NULL
|
WHERE n.entity_id IS NOT NULL
|
||||||
RETURN DISTINCT n.entity_id AS label
|
RETURN DISTINCT n.entity_id AS label
|
||||||
ORDER BY label
|
ORDER BY n.entity_id
|
||||||
$$) AS (label text)"""
|
$$) AS (label text)"""
|
||||||
% self.graph_name
|
% self.graph_name
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
labels = [self._decode_graph_label(result["label"]) for result in results]
|
labels = [result["label"] for result in results]
|
||||||
|
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
async def embed_nodes(
|
async def embed_nodes(
|
||||||
|
Reference in New Issue
Block a user