From 016d9f572d1d1ee72e6abf0278f6eef5b27376c8 Mon Sep 17 00:00:00 2001 From: Alex Potapenko Date: Fri, 20 Dec 2024 09:57:35 +0100 Subject: [PATCH] GremlinStorage: fix linting error, use asyncio.gather in get_node_edges() --- examples/lightrag_ollama_gremlin_demo.py | 8 ++++++-- lightrag/kg/gremlin_impl.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/lightrag_ollama_gremlin_demo.py b/examples/lightrag_ollama_gremlin_demo.py index aa3bd011..35ffece8 100644 --- a/examples/lightrag_ollama_gremlin_demo.py +++ b/examples/lightrag_ollama_gremlin_demo.py @@ -1,9 +1,13 @@ import asyncio import inspect -import logging import os -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN) +# Uncomment these lines below to filter out somewhat verbose INFO level +# logging prints (the default loglevel is INFO). +# This has to go before the lightrag imports to work, +# which triggers linting errors, so we keep it commented out: +# import logging +# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN) from lightrag import LightRAG, QueryParam from lightrag.llm import ollama_embedding, ollama_model_complete diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 4912e752..3cad6db0 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -306,13 +306,6 @@ class GremlinStorage(BaseGraphStorage): .project('connected_label') .by(__.label()) """ - result1 = await self._query(query1) - edges1 = ( - [(node_label, res["connected_label"]) for res in result1[0]] - if result1 - else [] - ) - query2 = f""" {self.traverse_source_name} .V().has('graph', '{self.graph_name}') @@ -322,7 +315,14 @@ class GremlinStorage(BaseGraphStorage): .project('connected_label') .by(__.select('connected').label()) """ - result2 = await self._query(query2) + result1, result2 = await asyncio.gather( + self._query(query1), self._query(query2) + ) + edges1 = ( + [(node_label, res["connected_label"]) for res in result1[0]] + if result1 + else [] + ) edges2 = ( [(res["connected_label"], node_label) for res in result2[0]] if result2