GremlinStorage: fix linting error, use asyncio.gather in get_node_edges()

This commit is contained in:
Alex Potapenko
2024-12-20 09:57:35 +01:00
parent 6f71293c83
commit 016d9f572d
2 changed files with 14 additions and 10 deletions

View File

@@ -1,9 +1,13 @@
import asyncio import asyncio
import inspect import inspect
import logging
import os import os
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN) # Uncomment these lines below to filter out somewhat verbose INFO level
# logging prints (the default loglevel is INFO).
# This has to go before the lightrag imports to work,
# which triggers linting errors, so we keep it commented out:
# import logging
# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_embedding, ollama_model_complete from lightrag.llm import ollama_embedding, ollama_model_complete

View File

@@ -306,13 +306,6 @@ class GremlinStorage(BaseGraphStorage):
.project('connected_label') .project('connected_label')
.by(__.label()) .by(__.label())
""" """
result1 = await self._query(query1)
edges1 = (
[(node_label, res["connected_label"]) for res in result1[0]]
if result1
else []
)
query2 = f""" query2 = f"""
{self.traverse_source_name} {self.traverse_source_name}
.V().has('graph', '{self.graph_name}') .V().has('graph', '{self.graph_name}')
@@ -322,7 +315,14 @@ class GremlinStorage(BaseGraphStorage):
.project('connected_label') .project('connected_label')
.by(__.select('connected').label()) .by(__.select('connected').label())
""" """
result2 = await self._query(query2) result1, result2 = await asyncio.gather(
self._query(query1), self._query(query2)
)
edges1 = (
[(node_label, res["connected_label"]) for res in result1[0]]
if result1
else []
)
edges2 = ( edges2 = (
[(res["connected_label"], node_label) for res in result2[0]] [(res["connected_label"], node_label) for res in result2[0]]
if result2 if result2