diff --git a/examples/lightrag_ollama_gremlin_demo.py b/examples/lightrag_ollama_gremlin_demo.py new file mode 100644 index 00000000..35ffece8 --- /dev/null +++ b/examples/lightrag_ollama_gremlin_demo.py @@ -0,0 +1,89 @@ +import asyncio +import inspect +import os + +# Uncomment these lines below to filter out somewhat verbose INFO level +# logging prints (the default loglevel is INFO). +# This has to go before the lightrag imports to work, +# which triggers linting errors, so we keep it commented out: +# import logging +# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN) + +from lightrag import LightRAG, QueryParam +from lightrag.llm import ollama_embedding, ollama_model_complete +from lightrag.utils import EmbeddingFunc + +WORKING_DIR = "./dickens_gremlin" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# Gremlin +os.environ["GREMLIN_HOST"] = "localhost" +os.environ["GREMLIN_PORT"] = "8182" +os.environ["GREMLIN_GRAPH"] = "dickens" + +# Creating a non-default source requires manual +# configuration and a restart on the server: use the dafault "g" +os.environ["GREMLIN_TRAVERSE_SOURCE"] = "g" + +# No authorization by default on docker tinkerpop/gremlin-server +os.environ["GREMLIN_USER"] = "" +os.environ["GREMLIN_PASSWORD"] = "" + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=ollama_model_complete, + llm_model_name="llama3.1:8b", + llm_model_max_async=4, + llm_model_max_token_size=32768, + llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}}, + embedding_func=EmbeddingFunc( + embedding_dim=768, + max_token_size=8192, + func=lambda texts: ollama_embedding( + texts, embed_model="nomic-embed-text", host="http://localhost:11434" + ), + ), + graph_storage="GremlinStorage", +) + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) + +# stream response +resp = rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid", stream=True), +) + + +async def print_stream(stream): + async for chunk in stream: + print(chunk, end="", flush=True) + + +if inspect.isasyncgen(resp): + asyncio.run(print_stream(resp)) +else: + print(resp) diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py new file mode 100644 index 00000000..3cad6db0 --- /dev/null +++ b/lightrag/kg/gremlin_impl.py @@ -0,0 +1,418 @@ +import asyncio +import inspect +import json +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +from gremlin_python.driver import client, serializer +from gremlin_python.driver.aiohttp.transport import AiohttpTransport +from gremlin_python.driver.protocol import GremlinServerError +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from lightrag.utils import logger + +from ..base import BaseGraphStorage + + +@dataclass +class GremlinStorage(BaseGraphStorage): + @staticmethod + def load_nx_graph(file_name): + print("no preloading of graph with Gremlin in production") + + # Will use this to make sure single quotes are properly escaped + escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'") + + def __init__(self, namespace, global_config, embedding_func): + super().__init__( + namespace=namespace, + global_config=global_config, + embedding_func=embedding_func, + ) + + self._driver = None + self._driver_lock = asyncio.Lock() + + USER = os.environ.get("GREMLIN_USER", "") + PASSWORD = os.environ.get("GREMLIN_PASSWORD", "") + HOST = os.environ["GREMLIN_HOST"] + PORT = int(os.environ["GREMLIN_PORT"]) + + # TraversalSource, a custom one has to be created manually, + # default it "g" + SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g") + + # All vertices will have graph={GRAPH} property, so that we can + # have several logical graphs for one source + GRAPH = GremlinStorage.escape_rx.sub( + r"\1\2'", + os.environ["GREMLIN_GRAPH"].replace("'", r"\'"), + ) + + self.traverse_source_name = SOURCE + self.graph_name = GRAPH + + self._driver = client.Client( + f"ws://{HOST}:{PORT}/gremlin", + SOURCE, + username=USER, + password=PASSWORD, + message_serializer=serializer.GraphSONSerializersV3d0(), + transport_factory=lambda: AiohttpTransport(call_from_event_loop=True), + ) + + def __post_init__(self): + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def close(self): + if self._driver: + self._driver.close() + self._driver = None + + async def __aexit__(self, exc_type, exc, tb): + if self._driver: + self._driver.close() + + async def index_done_callback(self): + print("KG successfully indexed.") + + @staticmethod + def _to_value_map(value: Any) -> str: + """Dump Python dict as Gremlin valueMap""" + json_str = json.dumps(value, ensure_ascii=False, sort_keys=False) + parsed_str = json_str.replace("'", r"\'") + + # walk over the string and replace curly brackets with square brackets + # outside of strings, as well as replace double quotes with single quotes + # and "deescape" double quotes inside of strings + outside_str = True + escaped = False + remove_indices = [] + for i, c in enumerate(parsed_str): + if escaped: + # previous character was an "odd" backslash + escaped = False + if c == '"': + # we want to "deescape" double quotes: store indices to delete + remove_indices.insert(0, i - 1) + elif c == "\\": + escaped = True + elif c == '"': + outside_str = not outside_str + parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :] + elif c == "{" and outside_str: + parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :] + elif c == "}" and outside_str: + parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :] + for idx in remove_indices: + parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :] + return parsed_str + + @staticmethod + def _convert_properties(properties: Dict[str, Any]) -> str: + """Create chained .property() commands from properties dict""" + props = [] + for k, v in properties.items(): + prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'")) + props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})") + return "".join(props) + + @staticmethod + def _fix_label(label: str) -> str: + """Strip double quotes and make sure single quotes are escaped""" + label = label.strip('"').replace("'", r"\'") + label = GremlinStorage.escape_rx.sub(r"\1\2'", label) + + return label + + async def _query(self, query: str) -> List[Dict[str, Any]]: + """ + Query the Gremlin graph + + Args: + query (str): a query to be executed + + Returns: + List[Dict[str, Any]]: a list of dictionaries containing the result set + """ + + result = list(await asyncio.wrap_future(self._driver.submit_async(query))) + + return result + + async def has_node(self, node_id: str) -> bool: + entity_name_label = GremlinStorage._fix_label(node_id) + + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label}') + .limit(1) + .hasNext() + """ + result = await self._query(query) + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + result[0][0], + ) + + return result[0][0] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + entity_name_label_source = GremlinStorage._fix_label(source_node_id) + entity_name_label_target = GremlinStorage._fix_label(target_node_id) + + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label_source}') + .bothE() + .otherV().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label_target}') + .limit(1) + .hasNext() + """ + result = await self._query(query) + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + result[0][0], + ) + + return result[0][0] + + async def get_node(self, node_id: str) -> Union[dict, None]: + entity_name_label = GremlinStorage._fix_label(node_id) + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label}') + .limit(1) + .project('properties') + .by(elementMap()) + """ + result = await self._query(query) + if result: + node = result[0][0] + node_dict = node["properties"] + logger.debug( + "{%s}: query: {%s}, result: {%s}", + inspect.currentframe().f_code.co_name, + query.format, + node_dict, + ) + return node_dict + + async def node_degree(self, node_id: str) -> int: + entity_name_label = GremlinStorage._fix_label(node_id) + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label}') + .outE() + .inV().has('graph', '{self.graph_name}') + .count() + .project('total_edge_count') + .by() + """ + result = await self._query(query) + edge_count = result[0][0]["total_edge_count"] + + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + edge_count, + ) + + return edge_count + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + logger.debug( + "{%s}:query:src_Degree+trg_degree:result:{%s}", + inspect.currentframe().f_code.co_name, + degrees, + ) + return degrees + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + """ + Find all edges between nodes of two given labels + + Args: + source_node_label (str): Label of the source nodes + target_node_label (str): Label of the target nodes + + Returns: + dict|None: Dict of found edge properties, or None of not found + """ + entity_name_label_source = GremlinStorage._fix_label(source_node_id) + entity_name_label_target = GremlinStorage._fix_label(target_node_id) + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label_source}') + .outE() + .inV().has('graph', '{self.graph_name}') + .hasLabel('{entity_name_label_target}') + .limit(1) + .project('edge_properties') + .by(__.bothE().elementMap()) + """ + result = await self._query(query) + if result: + edge_properties = result[0][0]["edge_properties"] + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + edge_properties, + ) + return edge_properties + + async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + """ + Retrieves all edges (relationships) for a particular node identified by its label. + :return: List of tuples containing edge sources and targets + """ + node_label = GremlinStorage._fix_label(source_node_id) + query1 = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{node_label}') + .out().has('graph', '{self.graph_name}') + .project('connected_label') + .by(__.label()) + """ + query2 = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .as('connected') + .out().has('graph', '{self.graph_name}') + .hasLabel('{node_label}') + .project('connected_label') + .by(__.select('connected').label()) + """ + result1, result2 = await asyncio.gather( + self._query(query1), self._query(query2) + ) + edges1 = ( + [(node_label, res["connected_label"]) for res in result1[0]] + if result1 + else [] + ) + edges2 = ( + [(res["connected_label"], node_label) for res in result2[0]] + if result2 + else [] + ) + + return edges1 + edges2 + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((GremlinServerError,)), + ) + async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + """ + Upsert a node in the Gremlin graph. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + label = GremlinStorage._fix_label(node_id) + properties = GremlinStorage._convert_properties(node_data) + + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{label}').fold() + .coalesce( + unfold(), + addV('{label}')) + .property('graph', '{self.graph_name}') + {properties} + """ + + try: + await self._query(query) + logger.debug( + "Upserted node with label '{%s}' and properties: {%s}", + label, + properties, + ) + except Exception as e: + logger.error("Error during upsert: {%s}", e) + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((GremlinServerError,)), + ) + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] + ): + """ + Upsert an edge and its properties between two nodes identified by their labels. + + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge + """ + source_node_label = GremlinStorage._fix_label(source_node_id) + target_node_label = GremlinStorage._fix_label(target_node_id) + edge_properties = GremlinStorage._convert_properties(edge_data) + + query = f""" + {self.traverse_source_name} + .V().has('graph', '{self.graph_name}') + .hasLabel('{source_node_label}').as('source') + .V().has('graph', '{self.graph_name}') + .hasLabel('{target_node_label}').as('target') + .coalesce( + select('source').outE('DIRECTED').where(inV().as('target')), + select('source').addE('DIRECTED').to(select('target')) + ) + .property('graph', '{self.graph_name}') + {edge_properties} + """ + try: + await self._query(query) + logger.debug( + "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", + source_node_label, + target_node_label, + edge_properties, + ) + except Exception as e: + logger.error("Error during edge upsert: {%s}", e) + raise + + async def _node2vec_embed(self): + print("Implemented but never called.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2661d4c6..81963328 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -81,6 +81,7 @@ TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") +GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage") def always_get_an_event_loop() -> asyncio.AbstractEventLoop: @@ -284,6 +285,7 @@ class LightRAG: "OracleGraphStorage": OracleGraphStorage, "AGEStorage": AGEStorage, "TiDBGraphStorage": TiDBGraphStorage, + "GremlinStorage": GremlinStorage, # "ArangoDBStorage": ArangoDBStorage } diff --git a/requirements.txt b/requirements.txt index b2df3253..abcc7f97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ aiohttp # database packages graspologic +gremlinpython hnswlib nano-vectordb neo4j