From fc3cc40a2e0035ee21a5cf1555822961ddeb3a82 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:07:52 +0100 Subject: [PATCH 01/33] Create json_kv_storage.py --- lightrag/storage/json_kv_storage.py | 149 ++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 lightrag/storage/json_kv_storage.py diff --git a/lightrag/storage/json_kv_storage.py b/lightrag/storage/json_kv_storage.py new file mode 100644 index 00000000..ddb1a863 --- /dev/null +++ b/lightrag/storage/json_kv_storage.py @@ -0,0 +1,149 @@ +""" +JsonDocStatus Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
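A minimal usage sketch of the key-value store this patch introduces. It assumes the
BaseKVStorage dataclass fields (namespace, global_config, embedding_func) as they exist
in the current lightrag code base; the namespace and working directory below are
illustrative values, not taken from this patch:

    import asyncio
    import os
    from lightrag.storage.json_kv_storage import JsonKVStorage

    async def demo():
        os.makedirs("./rag_storage", exist_ok=True)
        kv = JsonKVStorage(
            namespace="full_docs",                            # hypothetical namespace
            global_config={"working_dir": "./rag_storage"},
            embedding_func=None,                              # not used by this backend
        )
        await kv.upsert({"doc-1": {"content": "hello world"}})
        print(await kv.get_by_id("doc-1"))                    # {'content': 'hello world'}
        await kv.index_done_callback()                        # writes kv_store_full_docs.json

    asyncio.run(demo())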
+ +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import asyncio +import html +import os +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass +from typing import Any, Union, cast, Dict +import numpy as np + +import time + +from lightrag.utils import ( + logger, + load_json, + write_json, + compute_mdhash_id, +) + +from lightrag.base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + DocStatus, + DocProcessingStatus, + DocStatusStorage, +) + + +@dataclass +class JsonKVStorage(BaseKVStorage): + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + self._lock = asyncio.Lock() + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + + async def all_keys(self) -> list[str]: + return list(self._data.keys()) + + async def index_done_callback(self): + write_json(self._data, self._file_name) + + async def get_by_id(self, id): + return self._data.get(id, None) + + async def get_by_ids(self, ids, fields=None): + if fields is None: + return [self._data.get(id, None) for id in ids] + return [ + ( + {k: v for k, v in self._data[id].items() if k in fields} + if self._data.get(id, None) + else None + ) + for id in ids + ] + + async def filter_keys(self, data: list[str]) -> set[str]: + return set([s for s in data if s not in self._data]) + + async def upsert(self, data: dict[str, dict]): + left_data = {k: v for k, v in data.items() if k not in self._data} + self._data.update(left_data) + return left_data + + async def drop(self): + self._data = {} + + async def filter(self, filter_func): + """Filter key-value pairs based on a filter function + + Args: + filter_func: The filter function, which takes a value as an argument and returns a boolean value + + Returns: + Dict: Key-value pairs that meet the condition + """ + result = {} + async with self._lock: + for key, value in self._data.items(): + if filter_func(value): + result[key] = value + return result + + async def delete(self, ids: list[str]): + """Delete data with specified IDs + + Args: + ids: List of IDs to delete + """ + async with self._lock: + for id in ids: + if id in self._data: + del self._data[id] + await self.index_done_callback() + logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") + + From d274beb9d24eb16280e7599a955564ac7447186d Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:08:14 +0100 Subject: [PATCH 02/33] Create jsondocstatus_storage.py --- lightrag/storage/jsondocstatus_storage.py | 139 ++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 lightrag/storage/jsondocstatus_storage.py diff --git a/lightrag/storage/jsondocstatus_storage.py b/lightrag/storage/jsondocstatus_storage.py new file mode 100644 index 00000000..27da40db --- /dev/null +++ b/lightrag/storage/jsondocstatus_storage.py @@ -0,0 +1,139 @@ +""" +JsonDocStatus Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, 
and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import asyncio +import html +import os +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass +from typing import Any, Union, cast, Dict +import numpy as np + +import time + +from lightrag.utils import ( + logger, + load_json, + write_json, + compute_mdhash_id, +) + +from lightrag.base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + DocStatus, + DocProcessingStatus, + DocStatusStorage, +) + + +@dataclass +class JsonDocStatusStorage(DocStatusStorage): + """JSON implementation of document status storage""" + + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + logger.info(f"Loaded document status storage with {len(self._data)} records") + + async def filter_keys(self, data: list[str]) -> set[str]: + """Return keys that should be processed (not in storage or not successfully processed)""" + return set( + [ + k + for k in data + if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED + ] + ) + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + counts = {status: 0 for status in DocStatus} + for doc in self._data.values(): + counts[doc["status"]] += 1 + return counts + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} + + async def index_done_callback(self): + """Save data 
to file after indexing""" + write_json(self._data, self._file_name) + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: Dictionary of document IDs and their status data + """ + self._data.update(data) + await self.index_done_callback() + return data + + async def get_by_id(self, id: str): + return self._data.get(id) + + async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: + """Get document status by ID""" + return self._data.get(doc_id) + + async def delete(self, doc_ids: list[str]): + """Delete document status by IDs""" + for doc_id in doc_ids: + self._data.pop(doc_id, None) + await self.index_done_callback() From fc2ebd98db5417040ddca43271a10983b87b63f3 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:08:41 +0100 Subject: [PATCH 03/33] Create nano_vector_db.py --- lightrag/storage/nano_vector_db.py | 217 +++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 lightrag/storage/nano_vector_db.py diff --git a/lightrag/storage/nano_vector_db.py b/lightrag/storage/nano_vector_db.py new file mode 100644 index 00000000..1499076c --- /dev/null +++ b/lightrag/storage/nano_vector_db.py @@ -0,0 +1,217 @@ +""" +NanoVectorDB Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
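A minimal end-to-end sketch of the vector storage this patch introduces, using a toy
embedding function. The EmbeddingFunc wrapper, the embedding_batch_num config key and
the meta_fields argument reflect the current lightrag API as understood here and should
be read as assumptions; the namespace and paths are illustrative:

    import asyncio
    import os
    import numpy as np
    from lightrag.utils import EmbeddingFunc
    from lightrag.storage.nano_vector_db import NanoVectorDBStorage

    async def toy_embed(texts: list[str]) -> np.ndarray:
        # one deterministic 8-dimensional vector per input text
        return np.array([[float(len(t))] * 8 for t in texts])

    async def demo():
        os.makedirs("./rag_storage", exist_ok=True)
        store = NanoVectorDBStorage(
            namespace="chunks",                               # hypothetical namespace
            global_config={"working_dir": "./rag_storage",
                           "embedding_batch_num": 32},
            embedding_func=EmbeddingFunc(embedding_dim=8,
                                         max_token_size=8192,
                                         func=toy_embed),
            meta_fields={"full_doc_id"},
        )
        await store.upsert({"chunk-1": {"content": "hello world",
                                        "full_doc_id": "doc-1"}})
        print(await store.query("hello", top_k=1))            # best match: chunk-1
        await store.index_done_callback()                     # saves vdb_chunks.json

    asyncio.run(demo())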
+ +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import asyncio +import html +import os +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass +from typing import Any, Union, cast, Dict +import numpy as np +import pipmaster as pm + +if not pm.is_installed("nano-vectordb"): + pm.install("nano-vectordb") + +from nano_vectordb import NanoVectorDB +import time + +from lightrag.utils import ( + logger, + load_json, + write_json, + compute_mdhash_id, +) + +from lightrag.base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + DocStatus, + DocProcessingStatus, + DocStatusStorage, +) + + +@dataclass +class NanoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + self._client_file_name = os.path.join( + self.global_config["working_dir"], f"vdb_{self.namespace}.json" + ) + self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + + current_time = time.time() + list_data = [ + { + "__id__": k, + "__created_at__": current_time, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async( + total=len(embedding_tasks), desc="Generating embeddings", unit="batch" + ) + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + if len(embeddings) == len(list_data): + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + results = self._client.upsert(datas=list_data) + return results + else: + # sometimes the embedding is not returned correctly. just log it. 
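            # (A length mismatch here means at least one embedding batch failed or
            #  returned fewer vectors than inputs; in that case nothing is upserted
            #  and this method implicitly returns None.)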
+ logger.error( + f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" + ) + + async def query(self, query: str, top_k=5): + embedding = await self.embedding_func([query]) + embedding = embedding[0] + results = self._client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] + return results + + @property + def client_storage(self): + return getattr(self._client, "_NanoVectorDB__storage") + + async def delete(self, ids: list[str]): + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + self._client.delete(ids) + logger.info( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + + async def delete_entity(self, entity_name: str): + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Check if the entity exists + if self._client.get([entity_id]): + await self.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str): + try: + relations = [ + dp + for dp in self.client_storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + ids_to_delete = [relation["__id__"] for relation in relations] + + if ids_to_delete: + await self.delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") + + async def index_done_callback(self): + self._client.save() From dd94b0026ab2cf20b2ec249356238898f519118d Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:09:08 +0100 Subject: [PATCH 04/33] Create networkx_storage.py --- lightrag/storage/networkx_storage.py | 229 +++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 lightrag/storage/networkx_storage.py diff --git a/lightrag/storage/networkx_storage.py b/lightrag/storage/networkx_storage.py new file mode 100644 index 00000000..3969157e --- /dev/null +++ b/lightrag/storage/networkx_storage.py @@ -0,0 +1,229 @@ +""" +NetworkX Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. 
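A minimal usage sketch of the graph storage described above, assuming the
BaseGraphStorage dataclass fields (namespace, global_config) from the current lightrag
code base; the namespace and working directory are illustrative:

    import asyncio
    import os
    from lightrag.storage.networkx_storage import NetworkXStorage

    async def demo():
        os.makedirs("./rag_storage", exist_ok=True)
        graph = NetworkXStorage(
            namespace="chunk_entity_relation",       # hypothetical namespace
            global_config={"working_dir": "./rag_storage"},
        )
        await graph.upsert_node("ALICE", {"entity_type": "person"})
        await graph.upsert_node("BOB", {"entity_type": "person"})
        await graph.upsert_edge("ALICE", "BOB", {"description": "knows"})
        print(await graph.node_degree("ALICE"))       # 1
        print(await graph.get_edge("ALICE", "BOB"))   # {'description': 'knows'}
        await graph.index_done_callback()             # writes graph_chunk_entity_relation.graphml

    asyncio.run(demo())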
+ +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import html +import os +from dataclasses import dataclass +from typing import Any, Union, cast +import networkx as nx +import numpy as np + + +from lightrag.utils import ( + logger, +) + +from lightrag.base import ( + BaseGraphStorage, +) + + +@dataclass +class NetworkXStorage(BaseGraphStorage): + @staticmethod + def load_nx_graph(file_name) -> nx.Graph: + if os.path.exists(file_name): + return nx.read_graphml(file_name) + return None + + @staticmethod + def write_nx_graph(graph: nx.Graph, file_name): + logger.info( + f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" + ) + nx.write_graphml(graph, file_name) + + @staticmethod + def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + Return the largest connected component of the graph, with nodes and edges sorted in a stable way. + """ + from graspologic.utils import largest_connected_component + + graph = graph.copy() + graph = cast(nx.Graph, largest_connected_component(graph)) + node_mapping = { + node: html.unescape(node.upper().strip()) for node in graph.nodes() + } # type: ignore + graph = nx.relabel_nodes(graph, node_mapping) + return NetworkXStorage._stabilize_graph(graph) + + @staticmethod + def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + Ensure an undirected graph with the same relationships will always be read the same way. 
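        Concretely, nodes are added in sorted order and, for undirected graphs, each
        edge's endpoints are sorted before the edge list itself is sorted, so the same
        logical graph always serializes with a deterministic node and edge order.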
+ """ + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + + sorted_nodes = graph.nodes(data=True) + sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + + fixed_graph.add_nodes_from(sorted_nodes) + edges = list(graph.edges(data=True)) + + if not graph.is_directed(): + + def _sort_source_target(edge): + source, target, edge_data = edge + if source > target: + temp = source + source = target + target = temp + return source, target, edge_data + + edges = [_sort_source_target(edge) for edge in edges] + + def _get_edge_key(source: Any, target: Any) -> str: + return f"{source} -> {target}" + + edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) + + fixed_graph.add_edges_from(edges) + return fixed_graph + + def __post_init__(self): + self._graphml_xml_file = os.path.join( + self.global_config["working_dir"], f"graph_{self.namespace}.graphml" + ) + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + self._graph = preloaded_graph or nx.Graph() + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def index_done_callback(self): + NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + + async def has_node(self, node_id: str) -> bool: + return self._graph.has_node(node_id) + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + return self._graph.has_edge(source_node_id, target_node_id) + + async def get_node(self, node_id: str) -> Union[dict, None]: + return self._graph.nodes.get(node_id) + + async def node_degree(self, node_id: str) -> int: + return self._graph.degree(node_id) + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + return self._graph.degree(src_id) + self._graph.degree(tgt_id) + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + return self._graph.edges.get((source_node_id, target_node_id)) + + async def get_node_edges(self, source_node_id: str): + if self._graph.has_node(source_node_id): + return list(self._graph.edges(source_node_id)) + return None + + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + self._graph.add_node(node_id, **node_data) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + self._graph.add_edge(source_node_id, target_node_id, **edge_data) + + async def delete_node(self, node_id: str): + """ + Delete a node from the graph based on the specified node_id. 
+ + :param node_id: The node_id to delete + """ + if self._graph.has_node(node_id): + self._graph.remove_node(node_id) + logger.info(f"Node {node_id} deleted from the graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() + + # @TODO: NOT USED + async def _node2vec_embed(self): + from graspologic import embed + + embeddings, nodes = embed.node2vec_embed( + self._graph, + **self.global_config["node2vec_params"], + ) + + nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + return embeddings, nodes_ids + + def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node IDs to be deleted + """ + for node in nodes: + if self._graph.has_node(node): + self._graph.remove_node(node) + + def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + if self._graph.has_edge(source, target): + self._graph.remove_edge(source, target) From f9fdd4cb35bda2f6187abe4d06582bc6dbccc7d8 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:27:08 +0100 Subject: [PATCH 05/33] Create __init__.py --- lightrag/storage/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 lightrag/storage/__init__.py diff --git a/lightrag/storage/__init__.py b/lightrag/storage/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/lightrag/storage/__init__.py @@ -0,0 +1 @@ + From 4911f936e0d712601e1f2f9a93ba1e5acc59d4a1 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:29:09 +0100 Subject: [PATCH 06/33] Update json_kv_storage.py --- lightrag/storage/json_kv_storage.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/lightrag/storage/json_kv_storage.py b/lightrag/storage/json_kv_storage.py index ddb1a863..57fe765d 100644 --- a/lightrag/storage/json_kv_storage.py +++ b/lightrag/storage/json_kv_storage.py @@ -50,29 +50,17 @@ Usage: import asyncio -import html import os -from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass -from typing import Any, Union, cast, Dict -import numpy as np - -import time from lightrag.utils import ( logger, load_json, write_json, - compute_mdhash_id, ) from lightrag.base import ( - BaseGraphStorage, BaseKVStorage, - BaseVectorStorage, - DocStatus, - DocProcessingStatus, - DocStatusStorage, ) From befcb8269308a238b571b00c3421f59e9c522d66 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:30:39 +0100 Subject: [PATCH 07/33] Update jsondocstatus_storage.py --- lightrag/storage/jsondocstatus_storage.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/lightrag/storage/jsondocstatus_storage.py b/lightrag/storage/jsondocstatus_storage.py index 27da40db..8f326170 100644 --- a/lightrag/storage/jsondocstatus_storage.py +++ b/lightrag/storage/jsondocstatus_storage.py @@ -48,28 +48,17 @@ Usage: """ - -import asyncio -import html import os -from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass -from typing import Any, Union, cast, Dict -import numpy as np - -import time +from typing import Union, Dict from lightrag.utils import ( logger, 
load_json, write_json, - compute_mdhash_id, ) from lightrag.base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, DocStatus, DocProcessingStatus, DocStatusStorage, From 3fe780893a8db444698904b0184f73844217e3bd Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:31:23 +0100 Subject: [PATCH 08/33] Update nano_vector_db.py --- lightrag/storage/nano_vector_db.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/lightrag/storage/nano_vector_db.py b/lightrag/storage/nano_vector_db.py index 1499076c..f2372799 100644 --- a/lightrag/storage/nano_vector_db.py +++ b/lightrag/storage/nano_vector_db.py @@ -47,14 +47,10 @@ Usage: from lightrag.storage.networkx_storage import NetworkXStorage """ - - import asyncio -import html import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass -from typing import Any, Union, cast, Dict import numpy as np import pipmaster as pm @@ -66,18 +62,11 @@ import time from lightrag.utils import ( logger, - load_json, - write_json, compute_mdhash_id, ) from lightrag.base import ( - BaseGraphStorage, - BaseKVStorage, BaseVectorStorage, - DocStatus, - DocProcessingStatus, - DocStatusStorage, ) From 3c72299e3868725cf576158949997aee41e7a60f Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:31:46 +0100 Subject: [PATCH 09/33] Update networkx_storage.py --- lightrag/storage/networkx_storage.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightrag/storage/networkx_storage.py b/lightrag/storage/networkx_storage.py index 3969157e..493c551e 100644 --- a/lightrag/storage/networkx_storage.py +++ b/lightrag/storage/networkx_storage.py @@ -47,8 +47,6 @@ Usage: from lightrag.storage.networkx_storage import NetworkXStorage """ - - import html import os from dataclasses import dataclass From 609d060579206586fae9ca635e8cfc009ecf5464 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:33:04 +0100 Subject: [PATCH 10/33] Update requirements.txt --- requirements.txt | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index c372cf9b..0f4c18ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,37 +1,24 @@ accelerate aiofiles aiohttp -asyncpg configparser # database packages -graspologic -gremlinpython -nano-vectordb -neo4j networkx +graspologic -# TODO : Remove specific databases and move the installation to their corresponding files -# Use pipmaster for install if needed +# Basic modules numpy -oracledb pipmaster -psycopg-pool -psycopg[binary,pool] pydantic -pymilvus -pymongo -pymysql - +# File manipulation libraries PyPDF2 python-docx python-dotenv python-pptx -pyvis -redis + setuptools -sqlalchemy tenacity @@ -39,3 +26,5 @@ tenacity tiktoken tqdm xxhash + +# Extra libraries are installed when needed using pipmaster From 2f19ac36258aafaeb28ea386987f934518bba0ea Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:33:35 +0100 Subject: [PATCH 11/33] Update graph_visual_with_html.py --- examples/graph_visual_with_html.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index 56642185..d082a170 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -1,4 +1,8 @@ import networkx as nx +import pipmaster as pm +if not pm.is_installed("pyvis"): + pm.install("pyvis") + from pyvis.network import Network import random From 6d95f58f3497d66c70da835a5c17491adaf6f817 
Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:34:00 +0100 Subject: [PATCH 12/33] Update lightrag.py --- lightrag/lightrag.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7e8a3bb7..9a849921 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -38,10 +38,10 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP STORAGES = { - "JsonKVStorage": ".storage", - "NanoVectorDBStorage": ".storage", - "NetworkXStorage": ".storage", - "JsonDocStatusStorage": ".storage", + "NetworkXStorage": ".storage.networkx_storage", + "JsonKVStorage": ".storage.json_kv_storage", + "NanoVectorDBStorage": ".storage.nano_vector_db", + "JsonDocStatusStorage": ".storage.jsondocstatus_storage", "Neo4JStorage": ".kg.neo4j_impl", "OracleKVStorage": ".kg.oracle_impl", "OracleGraphStorage": ".kg.oracle_impl", From ce887618590847bc371237e2eb329e0638b147cd Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:34:23 +0100 Subject: [PATCH 13/33] Update storage.py --- lightrag/storage.py | 461 +------------------------------------------- 1 file changed, 1 insertion(+), 460 deletions(-) diff --git a/lightrag/storage.py b/lightrag/storage.py index 3bee911b..91ba7bcc 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -1,460 +1 @@ -import asyncio -import html -import os -from tqdm.asyncio import tqdm as tqdm_async -from dataclasses import dataclass -from typing import Any, Union, cast, Dict -import networkx as nx -import numpy as np - -from nano_vectordb import NanoVectorDB -import time - -from .utils import ( - logger, - load_json, - write_json, - compute_mdhash_id, -) - -from .base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, - DocStatus, - DocProcessingStatus, - DocStatusStorage, -) - - -@dataclass -class JsonKVStorage(BaseKVStorage): - def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - self._lock = asyncio.Lock() - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - - async def all_keys(self) -> list[str]: - return list(self._data.keys()) - - async def index_done_callback(self): - write_json(self._data, self._file_name) - - async def get_by_id(self, id): - return self._data.get(id, None) - - async def get_by_ids(self, ids, fields=None): - if fields is None: - return [self._data.get(id, None) for id in ids] - return [ - ( - {k: v for k, v in self._data[id].items() if k in fields} - if self._data.get(id, None) - else None - ) - for id in ids - ] - - async def filter_keys(self, data: list[str]) -> set[str]: - return set([s for s in data if s not in self._data]) - - async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - return left_data - - async def drop(self): - self._data = {} - - async def filter(self, filter_func): - """Filter key-value pairs based on a filter function - - Args: - filter_func: The filter function, which takes a value as an argument and returns a boolean value - - Returns: - Dict: Key-value pairs that meet the condition - """ - result = {} - async with self._lock: - for key, value in self._data.items(): - if filter_func(value): - result[key] = value - return result - - async def delete(self, ids: list[str]): - """Delete data with specified IDs - - Args: - ids: List of IDs to 
delete - """ - async with self._lock: - for id in ids: - if id in self._data: - del self._data[id] - await self.index_done_callback() - logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") - - -@dataclass -class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = 0.2 - - def __post_init__(self): - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) - self._max_batch_size = self.global_config["embedding_batch_num"] - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) - self.cosine_better_than_threshold = self.global_config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) - - async def upsert(self, data: dict[str, dict]): - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") - if not len(data): - logger.warning("You insert an empty data to vector DB") - return [] - - current_time = time.time() - list_data = [ - { - "__id__": k, - "__created_at__": current_time, - **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - - async def wrapped_task(batch): - result = await self.embedding_func(batch) - pbar.update(1) - return result - - embedding_tasks = [wrapped_task(batch) for batch in batches] - pbar = tqdm_async( - total=len(embedding_tasks), desc="Generating embeddings", unit="batch" - ) - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - if len(embeddings) == len(list_data): - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) - return results - else: - # sometimes the embedding is not returned correctly. just log it. 
- logger.error( - f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" - ) - - async def query(self, query: str, top_k=5): - embedding = await self.embedding_func([query]) - embedding = embedding[0] - results = self._client.query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] - return results - - @property - def client_storage(self): - return getattr(self._client, "_NanoVectorDB__storage") - - async def delete(self, ids: list[str]): - """Delete vectors with specified IDs - - Args: - ids: List of vector IDs to be deleted - """ - try: - self._client.delete(ids) - logger.info( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - - async def delete_entity(self, entity_name: str): - try: - entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" - ) - # Check if the entity exists - if self._client.get([entity_id]): - await self.delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - - async def delete_entity_relation(self, entity_name: str): - try: - relations = [ - dp - for dp in self.client_storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") - ids_to_delete = [relation["__id__"] for relation in relations] - - if ids_to_delete: - await self.delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for {entity_name}: {e}") - - async def index_done_callback(self): - self._client.save() - - -@dataclass -class NetworkXStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name) -> nx.Graph: - if os.path.exists(file_name): - return nx.read_graphml(file_name) - return None - - @staticmethod - def write_nx_graph(graph: nx.Graph, file_name): - logger.info( - f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" - ) - nx.write_graphml(graph, file_name) - - @staticmethod - def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Return the largest connected component of the graph, with nodes and edges sorted in a stable way. - """ - from graspologic.utils import largest_connected_component - - graph = graph.copy() - graph = cast(nx.Graph, largest_connected_component(graph)) - node_mapping = { - node: html.unescape(node.upper().strip()) for node in graph.nodes() - } # type: ignore - graph = nx.relabel_nodes(graph, node_mapping) - return NetworkXStorage._stabilize_graph(graph) - - @staticmethod - def _stabilize_graph(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Ensure an undirected graph with the same relationships will always be read the same way. 
- """ - fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() - - sorted_nodes = graph.nodes(data=True) - sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) - - fixed_graph.add_nodes_from(sorted_nodes) - edges = list(graph.edges(data=True)) - - if not graph.is_directed(): - - def _sort_source_target(edge): - source, target, edge_data = edge - if source > target: - temp = source - source = target - target = temp - return source, target, edge_data - - edges = [_sort_source_target(edge) for edge in edges] - - def _get_edge_key(source: Any, target: Any) -> str: - return f"{source} -> {target}" - - edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) - - fixed_graph.add_edges_from(edges) - return fixed_graph - - def __post_init__(self): - self._graphml_xml_file = os.path.join( - self.global_config["working_dir"], f"graph_{self.namespace}.graphml" - ) - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - self._graph = preloaded_graph or nx.Graph() - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } - - async def index_done_callback(self): - NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) - - async def has_node(self, node_id: str) -> bool: - return self._graph.has_node(node_id) - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - return self._graph.has_edge(source_node_id, target_node_id) - - async def get_node(self, node_id: str) -> Union[dict, None]: - return self._graph.nodes.get(node_id) - - async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - return self._graph.edges.get((source_node_id, target_node_id)) - - async def get_node_edges(self, source_node_id: str): - if self._graph.has_node(source_node_id): - return list(self._graph.edges(source_node_id)) - return None - - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - self._graph.add_node(node_id, **node_data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): - self._graph.add_edge(source_node_id, target_node_id, **edge_data) - - async def delete_node(self, node_id: str): - """ - Delete a node from the graph based on the specified node_id. 
- - :param node_id: The node_id to delete - """ - if self._graph.has_node(node_id): - self._graph.remove_node(node_id) - logger.info(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") - - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - if algorithm not in self._node_embed_algorithms: - raise ValueError(f"Node embedding algorithm {algorithm} not supported") - return await self._node_embed_algorithms[algorithm]() - - # @TODO: NOT USED - async def _node2vec_embed(self): - from graspologic import embed - - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.global_config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] - return embeddings, nodes_ids - - def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node IDs to be deleted - """ - for node in nodes: - if self._graph.has_node(node): - self._graph.remove_node(node) - - def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - if self._graph.has_edge(source, target): - self._graph.remove_edge(source, target) - - -@dataclass -class JsonDocStatusStorage(DocStatusStorage): - """JSON implementation of document status storage""" - - def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - logger.info(f"Loaded document status storage with {len(self._data)} records") - - async def filter_keys(self, data: list[str]) -> set[str]: - """Return keys that should be processed (not in storage or not successfully processed)""" - return set( - [ - k - for k in data - if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED - ] - ) - - async def get_status_counts(self) -> Dict[str, int]: - """Get counts of documents in each status""" - counts = {status: 0 for status in DocStatus} - for doc in self._data.values(): - counts[doc["status"]] += 1 - return counts - - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} - - async def index_done_callback(self): - """Save data to file after indexing""" - write_json(self._data, self._file_name) - - async def upsert(self, data: dict[str, dict]): - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ - self._data.update(data) - await self.index_done_callback() - return data - - async def get_by_id(self, id: str): - return self._data.get(id) - - async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: - """Get document status by ID""" - return self._data.get(doc_id) - - async def delete(self, doc_ids: list[str]): - """Delete document status by IDs""" - for doc_id in doc_ids: - self._data.pop(doc_id, None) - await self.index_done_callback() +# This file is not needed anymore (TODO: remove) From 09c55032bdc9e6c2c6c9f51225b8ea4324894589 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:34:43 +0100 Subject: 
[PATCH 14/33] Update requirements.txt --- lightrag/api/requirements.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index fc5afd58..7b2593c0 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -1,10 +1,7 @@ ascii_colors fastapi -nano_vectordb nest_asyncio numpy -ollama -openai pipmaster python-dotenv python-multipart @@ -12,5 +9,4 @@ tenacity tiktoken torch tqdm -transformers uvicorn From af245eb73eb88dab3ed6c52af575fe285e21709c Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:35:08 +0100 Subject: [PATCH 15/33] Update age_impl.py --- lightrag/kg/age_impl.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 275f5775..df32b7cb 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -6,6 +6,14 @@ import sys from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +import pipmaster as pm + +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + import psycopg from psycopg.rows import namedtuple_row From 9390abb49b808e21d9478fe5de7e589782ea0488 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:35:26 +0100 Subject: [PATCH 16/33] Update milvus_impl.py --- lightrag/kg/milvus_impl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index bf20ffd7..905a08b5 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,6 +6,9 @@ import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage +import pipmaster as pm +if not pm.is_installed("pymilvus"): + pm.install("pymilvus") from pymilvus import MilvusClient From 7a5d058a57415c7c7a343aa4c27c9124d3ef9a73 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:35:50 +0100 Subject: [PATCH 17/33] Update mongo_impl.py --- lightrag/kg/mongo_impl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index fbbae8c2..9515514a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,6 +1,10 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass +import pipmaster as pm +if not pm.is_installed("pymongo"): + pm.install("pymongo") + from pymongo import MongoClient from typing import Union from lightrag.utils import logger From 3fdeeff8ba6f1c5eba094c47ad39a5f206990773 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:36:10 +0100 Subject: [PATCH 18/33] Update neo4j_impl.py --- lightrag/kg/neo4j_impl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 4392a834..cd552122 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,6 +3,9 @@ import inspect import os from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict +import pipmaster as pm +if not pm.is_installed("neo4j"): + pm.install("neo4j") from neo4j import ( AsyncGraphDatabase, From ecadb71556d78c4d75edd22931ac8906e5042400 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:36:34 +0100 Subject: [PATCH 19/33] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 5 +++++ 1 file changed, 5 insertions(+) 
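Patches 15 through 19 (the AGE, Milvus, Mongo, Neo4j and Oracle implementations) apply
the same lazy-dependency pattern: optional database clients are dropped from
requirements.txt and installed at import time only when the corresponding backend
module is actually loaded. The general shape of the guard is sketched below; the
package and module names are placeholders, not a specific backend from this series:

    import pipmaster as pm

    # Install the optional client on first use instead of pinning it in requirements.txt.
    if not pm.is_installed("some-client"):    # hypothetical PyPI package name
        pm.install("some-client")

    import some_client                        # hypothetical module, importable after install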
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index f93d2816..2d1f631c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -6,6 +6,11 @@ from dataclasses import dataclass from typing import Union import numpy as np import array +import pipmaster as pm + +if not pm.is_installed("oracledb"): + pm.install("oracledb") + from ..utils import logger from ..base import ( From c7c56863b1df0c090789c10cc19273d460ad037a Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:36:53 +0100 Subject: [PATCH 20/33] Update postgres_impl.py --- lightrag/kg/postgres_impl.py | 1254 +++------------------------------- 1 file changed, 108 insertions(+), 1146 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 86072c9f..eb6e6e73 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1,30 +1,25 @@ import asyncio -import inspect -import json -import os -import time -from dataclasses import dataclass -from typing import Union, List, Dict, Set, Any, Tuple -import numpy as np -import asyncpg import sys -from tqdm.asyncio import tqdm as tqdm_async -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) +import os +import pipmaster as pm -from ..utils import logger -from ..base import ( - BaseKVStorage, - BaseVectorStorage, - DocStatusStorage, - DocStatus, - DocProcessingStatus, - BaseGraphStorage, -) +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + +import asyncpg +import psycopg +from psycopg_pool import AsyncConnectionPool +from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage + +DB = "rag" +USER = "rag" +PASSWORD = "rag" +HOST = "localhost" +PORT = "15432" +os.environ["AGE_GRAPH_NAME"] = "dickens" if sys.platform.startswith("win"): import asyncio.windows_events @@ -32,1143 +27,110 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -class PostgreSQLDB: - def __init__(self, config, **kwargs): - self.pool = None - self.host = config.get("host", "localhost") - self.port = config.get("port", 5432) - self.user = config.get("user", "postgres") - self.password = config.get("password", None) - self.database = config.get("database", "postgres") - self.workspace = config.get("workspace", "default") - self.max = 12 - self.increment = 1 - logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier") +async def get_pool(): + return await asyncpg.create_pool( + f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}", + min_size=10, + max_size=10, + max_queries=5000, + max_inactive_connection_lifetime=300.0, + ) - if self.user is None or self.password is None or self.database is None: - raise ValueError( - "Missing database user, password, or database in addon_params" - ) - async def initdb(self): - try: - self.pool = await asyncpg.create_pool( - user=self.user, - password=self.password, - database=self.database, - host=self.host, - port=self.port, - min_size=1, - max_size=self.max, - ) +async def main1(): + connection_string = ( + f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" + ) + pool = AsyncConnectionPool(connection_string, open=False) + await pool.open() - logger.info( - f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}" - ) - except Exception as e: - logger.error( - f"Failed to connect to 
PostgreSQL database at {self.host}:{self.port}/{self.database}" - ) - logger.error(f"PostgreSQL database error: {e}") - raise - - async def check_tables(self): - for k, v in TABLES.items(): + try: + conn = await pool.getconn(timeout=10) + async with conn.cursor() as curs: try: - await self.query("SELECT 1 FROM {k} LIMIT 1".format(k=k)) - except Exception as e: - logger.error(f"Failed to check table {k} in PostgreSQL database") - logger.error(f"PostgreSQL database error: {e}") - try: - await self.execute(v["ddl"]) - logger.info(f"Created table {k} in PostgreSQL database") - except Exception as e: - logger.error(f"Failed to create table {k} in PostgreSQL database") - logger.error(f"PostgreSQL database error: {e}") - - logger.info("Finished checking all tables in PostgreSQL database") - - async def query( - self, - sql: str, - params: dict = None, - multirows: bool = False, - for_age: bool = False, - graph_name: str = None, - ) -> Union[dict, None, list[dict]]: - async with self.pool.acquire() as connection: - try: - if for_age: - await PostgreSQLDB._prerequisite(connection, graph_name) - if params: - rows = await connection.fetch(sql, *params.values()) - else: - rows = await connection.fetch(sql) - - if multirows: - if rows: - columns = [col for col in rows[0].keys()] - data = [dict(zip(columns, row)) for row in rows] - else: - data = [] - else: - if rows: - columns = rows[0].keys() - data = dict(zip(columns, rows[0])) - else: - data = None - return data - except Exception as e: - logger.error(f"PostgreSQL database error: {e}") - print(sql) - print(params) - raise - - async def execute( - self, - sql: str, - data: Union[list, dict] = None, - for_age: bool = False, - graph_name: str = None, - upsert: bool = False, - ): - try: - async with self.pool.acquire() as connection: - if for_age: - await PostgreSQLDB._prerequisite(connection, graph_name) - - if data is None: - await connection.execute(sql) - else: - await connection.execute(sql, *data.values()) - except ( - asyncpg.exceptions.UniqueViolationError, - asyncpg.exceptions.DuplicateTableError, - ) as e: - if upsert: - print("Key value duplicate, but upsert succeeded.") - else: - logger.error(f"Upsert error: {e}") - except Exception as e: - logger.error(f"PostgreSQL database error: {e.__class__} - {e}") - print(sql) - print(data) - raise - - @staticmethod - async def _prerequisite(conn: asyncpg.Connection, graph_name: str): - try: - await conn.execute('SET search_path = ag_catalog, "$user", public') - await conn.execute(f"""select create_graph('{graph_name}')""") - except ( - asyncpg.exceptions.InvalidSchemaNameError, - asyncpg.exceptions.UniqueViolationError, - ): - pass - - -@dataclass -class PGKVStorage(BaseKVStorage): - db: PostgreSQLDB = None - - def __post_init__(self): - self._max_batch_size = self.global_config["embedding_batch_num"] - - ################ QUERY METHODS ################ - - async def get_by_id(self, id: str) -> Union[dict, None]: - """Get doc_full data by id.""" - sql = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"workspace": self.db.workspace, "id": id} - if "llm_response_cache" == self.namespace: - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - else: - res = await self.db.query(sql, params) - if res: - return res - else: - return None - - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] - params = 
{"workspace": self.db.workspace, mode: mode, "id": id} - if "llm_response_cache" == self.namespace: - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - return res - else: - return None - - # Query by id - async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]: - """Get doc_chunks data by id""" - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - if "llm_response_cache" == self.namespace: - array_res = await self.db.query(sql, params, multirows=True) - modes = set() - dict_res: dict[str, dict] = {} - for row in array_res: - modes.add(row["mode"]) - for mode in modes: - if mode not in dict_res: - dict_res[mode] = {} - for row in array_res: - dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] - else: - res = await self.db.query(sql, params, multirows=True) - if res: - return res - else: - return None - - async def all_keys(self) -> list[dict]: - if "llm_response_cache" == self.namespace: - sql = "select workspace,mode,id from lightrag_llm_cache" - res = await self.db.query(sql, multirows=True) - return res - else: - logger.error( - f"all_keys is only implemented for llm_response_cache, not for {self.namespace}" - ) - - async def filter_keys(self, keys: List[str]) -> Set[str]: - """Filter out duplicated content""" - sql = SQL_TEMPLATES["filter_keys"].format( - table_name=NAMESPACE_TABLE_MAP[self.namespace], - ids=",".join([f"'{id}'" for id in keys]), - ) - params = {"workspace": self.db.workspace} - try: - res = await self.db.query(sql, params, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data - except Exception as e: - logger.error(f"PostgreSQL database error: {e}") - print(sql) - print(params) - - ################ INSERT METHODS ################ - async def upsert(self, data: Dict[str, dict]): - if self.namespace == "text_chunks": - pass - elif self.namespace == "full_docs": - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_doc_full"] - _data = { - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - } - await self.db.execute(upsert_sql, _data) - elif self.namespace == "llm_response_cache": - for mode, items in data.items(): - for k, v in items.items(): - upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] - _data = { - "workspace": self.db.workspace, - "id": k, - "original_prompt": v["original_prompt"], - "return_value": v["return"], - "mode": mode, - } - - await self.db.execute(upsert_sql, _data) - - async def index_done_callback(self): - if self.namespace in ["full_docs", "text_chunks"]: - logger.info("full doc and chunk data had been saved into postgresql db!") - - -@dataclass -class PGVectorStorage(BaseVectorStorage): - cosine_better_than_threshold: float = 0.2 - db: PostgreSQLDB = None - - def __post_init__(self): - self._max_batch_size = self.global_config["embedding_batch_num"] - self.cosine_better_than_threshold = self.global_config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) - - def _upsert_chunks(self, item: dict): - try: - upsert_sql = SQL_TEMPLATES["upsert_chunk"] - data = { - "workspace": self.db.workspace, - "id": item["__id__"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content": 
item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - } - except Exception as e: - logger.error(f"Error to prepare upsert sql: {e}") - print(item) - raise e - return upsert_sql, data - - def _upsert_entities(self, item: dict): - upsert_sql = SQL_TEMPLATES["upsert_entity"] - data = { - "workspace": self.db.workspace, - "id": item["__id__"], - "entity_name": item["entity_name"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - } - return upsert_sql, data - - def _upsert_relationships(self, item: dict): - upsert_sql = SQL_TEMPLATES["upsert_relationship"] - data = { - "workspace": self.db.workspace, - "id": item["__id__"], - "source_id": item["src_id"], - "target_id": item["tgt_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - } - return upsert_sql, data - - async def upsert(self, data: Dict[str, dict]): - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") - if not len(data): - logger.warning("You insert an empty data to vector DB") - return [] - current_time = time.time() - list_data = [ - { - "__id__": k, - "__created_at__": current_time, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - - async def wrapped_task(batch): - result = await self.embedding_func(batch) - pbar.update(1) - return result - - embedding_tasks = [wrapped_task(batch) for batch in batches] - pbar = tqdm_async( - total=len(embedding_tasks), desc="Generating embeddings", unit="batch" - ) - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - for item in list_data: - if self.namespace == "chunks": - upsert_sql, data = self._upsert_chunks(item) - elif self.namespace == "entities": - upsert_sql, data = self._upsert_entities(item) - elif self.namespace == "relationships": - upsert_sql, data = self._upsert_relationships(item) - else: - raise ValueError(f"{self.namespace} is not supported") - - await self.db.execute(upsert_sql, data) - - async def index_done_callback(self): - logger.info("vector data had been saved into postgresql db!") - - #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" - embeddings = await self.embedding_func([query]) - embedding = embeddings[0] - embedding_string = ",".join(map(str, embedding)) - - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) - params = { - "workspace": self.db.workspace, - "better_than_threshold": self.cosine_better_than_threshold, - "top_k": top_k, - } - results = await self.db.query(sql, params=params, multirows=True) - return results - - -@dataclass -class PGDocStatusStorage(DocStatusStorage): - """PostgreSQL implementation of document status storage""" - - db: PostgreSQLDB = None - - def __post_init__(self): + await curs.execute('SET search_path = ag_catalog, "$user", public') + await curs.execute("SELECT create_graph('dickens-2')") + await conn.commit() + print("create_graph success") + except ( + psycopg.errors.InvalidSchemaName, + psycopg.errors.UniqueViolation, + ): + print("create_graph already exists") + await conn.rollback() + finally: pass - async def filter_keys(self, data: list[str]) -> set[str]: - """Return 
keys that don't exist in storage""" - keys = ",".join([f"'{_id}'" for _id in data]) - sql = ( - f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" - ) - result = await self.db.query(sql, {"workspace": self.db.workspace}, True) - # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. - if result is None: - return set(data) - else: - existed = set([element["id"] for element in result]) - return set(data) - existed - async def get_status_counts(self) -> Dict[str, int]: - """Get counts of documents in each status""" - sql = """SELECT status as "status", COUNT(1) as "count" - FROM LIGHTRAG_DOC_STATUS - where workspace=$1 GROUP BY STATUS - """ - result = await self.db.query(sql, {"workspace": self.db.workspace}, True) - # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] - counts = {} - for doc in result: - counts[doc["status"]] = doc["count"] - return counts - - async def get_docs_by_status( - self, status: DocStatus - ) -> Dict[str, DocProcessingStatus]: - """Get all documents by status""" - sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1" - params = {"workspace": self.db.workspace, "status": status} - result = await self.db.query(sql, params, True) - # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] - # Converting to be a dict - return { - element["id"]: DocProcessingStatus( - content_summary=element["content_summary"], - content_length=element["content_length"], - status=element["status"], - created_at=element["created_at"], - updated_at=element["updated_at"], - chunks_count=element["chunks_count"], - ) - for element in result - } - - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def index_done_callback(self): - """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" - logger.info("Doc status had been saved into postgresql db!") - - async def upsert(self, data: dict[str, dict]): - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status) - values($1,$2,$3,$4,$5,$6) - on conflict(id,workspace) do update set - content_summary = EXCLUDED.content_summary, - content_length = EXCLUDED.content_length, - chunks_count = EXCLUDED.chunks_count, - status = EXCLUDED.status, - updated_at = CURRENT_TIMESTAMP""" - for k, v in data.items(): - # chunks_count is optional - await self.db.execute( - sql, - { - "workspace": self.db.workspace, - "id": k, - "content_summary": v["content_summary"], - "content_length": v["content_length"], - "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, - "status": v["status"], - }, - ) - return data - - -class PGGraphQueryException(Exception): - """Exception for the AGE queries.""" - - def __init__(self, exception: Union[str, Dict]) -> None: - if isinstance(exception, dict): - self.message = exception["message"] if "message" in exception else "unknown" - self.details = exception["details"] if "details" in exception else "unknown" - else: - 
self.message = exception - self.details = "unknown" - - def get_message(self) -> str: - return self.message - - def get_details(self) -> Any: - return self.details - - -@dataclass -class PGGraphStorage(BaseGraphStorage): - db: PostgreSQLDB = None - - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with AGE in production") - - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - self.graph_name = os.environ["AGE_GRAPH_NAME"] - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } - - async def index_done_callback(self): - print("KG successfully indexed.") - - @staticmethod - def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: - """ - Convert a record returned from an age query to a dictionary - - Args: - record (): a record from an age query result - - Returns: - Dict[str, Any]: a dictionary representation of the record where - the dictionary key is the field name and the value is the - value converted to a python type - """ - # result holder - d = {} - - # prebuild a mapping of vertex_id to vertex mappings to be used - # later to build edges - vertices = {} - for k in record.keys(): - v = record[k] - # agtype comes back '{key: value}::type' which must be parsed - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = v.split("::")[0] - if dtype == "vertex": - vertex = json.loads(v) - vertices[vertex["id"]] = vertex.get("properties") - - # iterate returned fields and parse appropriately - for k in record.keys(): - v = record[k] - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = v.split("::")[0] - else: - dtype = "" - - if dtype == "vertex": - vertex = json.loads(v) - field = vertex.get("properties") - if not field: - field = {} - field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) - d[k] = field - # convert edge from id-label->id by replacing id with node information - # we only do this if the vertex was also returned in the query - # this is an attempt to be consistent with neo4j implementation - elif dtype == "edge": - edge = json.loads(v) - d[k] = ( - vertices.get(edge["start_id"], {}), - edge[ - "label" - ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" - vertices.get(edge["end_id"], {}), - ) - else: - d[k] = json.loads(v) if isinstance(v, str) else v - - return d - - @staticmethod - def _format_properties( - properties: Dict[str, Any], _id: Union[str, None] = None - ) -> str: - """ - Convert a dictionary of properties to a string representation that - can be used in a cypher query insert/merge statement. 
- - Args: - properties (Dict[str,str]): a dictionary containing node/edge properties - _id (Union[str, None]): the id of the node or None if none exists - - Returns: - str: the properties dictionary as a properly formatted string - """ - props = [] - # wrap property key in backticks to escape - for k, v in properties.items(): - prop = f"`{k}`: {json.dumps(v)}" - props.append(prop) - if _id is not None and "id" not in properties: - props.append( - f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" - ) - return "{" + ", ".join(props) + "}" - - @staticmethod - def _encode_graph_label(label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - label (str): the original label - - Returns: - str: the encoded label - """ - return "x" + label.encode().hex() - - @staticmethod - def _decode_graph_label(encoded_label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - encoded_label (str): the encoded label - - Returns: - str: the decoded label - """ - return bytes.fromhex(encoded_label.removeprefix("x")).decode() - - @staticmethod - def _get_col_name(field: str, idx: int) -> str: - """ - Convert a cypher return field to a pgsql select field - If possible keep the cypher column name, but create a generic name if necessary - - Args: - field (str): a return field from a cypher query to be formatted for pgsql - idx (int): the position of the field in the return statement - - Returns: - str: the field to be used in the pgsql select statement - """ - # remove white space - field = field.strip() - # if an alias is provided for the field, use it - if " as " in field: - return field.split(" as ")[-1].strip() - # if the return value is an unnamed primitive, give it a generic name - if field.isnumeric() or field in ("true", "false", "null"): - return f"column_{idx}" - # otherwise return the value stripping out some common special chars - return field.replace("(", "_").replace(")", "") - - async def _query( - self, query: str, readonly: bool = True, upsert: bool = False - ) -> List[Dict[str, Any]]: - """ - Query the graph by taking a cypher query, converting it to an - age compatible query, executing it and converting the result - - Args: - query (str): a cypher query to be executed - params (dict): parameters for the query - - Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set - """ - # convert cypher query to pgsql/age query - wrapped_query = query - - # execute the query, rolling back on an error - try: - if readonly: - data = await self.db.query( - wrapped_query, - multirows=True, - for_age=True, - graph_name=self.graph_name, - ) - else: - data = await self.db.execute( - wrapped_query, - for_age=True, - graph_name=self.graph_name, - upsert=upsert, - ) - except Exception as e: - raise PGGraphQueryException( - { - "message": f"Error executing graph query: {query}", - "wrapped": wrapped_query, - "detail": str(e), - } - ) from e - - if data is None: - result = [] - # decode records - else: - result = [PGGraphStorage._record_to_dict(d) for d in data] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - RETURN count(n) > 0 AS node_exists - $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) - - single_result = (await self._query(query))[0] 
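# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# The has_node() query just above wraps a Cypher statement in Apache AGE's
# cypher() table function and hands the resulting plain-SQL string to _query(),
# which runs it through PostgreSQLDB with for_age=True so the search_path and
# create_graph prerequisites are applied first. The graph name and node id here
# are example values taken from the surrounding test code.
graph_name = "dickens"
node_label = "x" + "SCROOGE".encode().hex()   # same hex scheme as _encode_graph_label
node_exists_sql = """SELECT * FROM cypher('%s', $$
    MATCH (n:Entity {node_id: "%s"})
    RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (graph_name, node_label)
print(node_exists_sql)   # the exact text passed to PostgreSQLDB.query(..., for_age=True)
# -------------------------------------------------------------------------------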
- logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - single_result["node_exists"], - ) - - return single_result["node_exists"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) - RETURN COUNT(r) > 0 AS edge_exists - $$) AS (edge_exists bool)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - - single_result = (await self._query(query))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - single_result["edge_exists"], - ) - return single_result["edge_exists"] - - async def get_node(self, node_id: str) -> Union[dict, None]: - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - RETURN n - $$) AS (n agtype)""" % (self.graph_name, label) - record = await self._query(query) - if record: - node = record[0] - node_dict = node["n"] - logger.debug( - "{%s}: query: {%s}, result: {%s}", - inspect.currentframe().f_code.co_name, - query, - node_dict, - ) - return node_dict - return None - - async def node_degree(self, node_id: str) -> int: - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"})-[]->(x) - RETURN count(x) AS total_edge_count - $$) AS (total_edge_count integer)""" % (self.graph_name, label) - record = (await self._query(query))[0] - if record: - edge_count = int(record["total_edge_count"]) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_count, - ) - return edge_count - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - src_degree = await self.node_degree(src_id) - trg_degree = await self.node_degree(tgt_id) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_id (str): Label of the source nodes - target_node_id (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) - RETURN properties(r) as edge_properties - LIMIT 1 - $$) AS (edge_properties agtype)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - record = await self._query(query) - if record and record[0] and record[0]["edge_properties"]: - result = record[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result, - ) - return result - - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: - """ - Retrieves all edges (relationships) for a 
particular node identified by its label. - :return: List of dictionaries containing edge information - """ - label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected - $$) AS (n agtype, r agtype, connected agtype)""" % ( - self.graph_name, - label, - ) - - results = await self._query(query) - edges = [] - for record in results: - source_node = record["n"] if record["n"] else None - connected_node = record["connected"] if record["connected"] else None - - source_label = ( - source_node["node_id"] - if source_node and source_node["node_id"] - else None - ) - target_label = ( - connected_node["node_id"] - if connected_node and connected_node["node_id"] - else None - ) - - if source_label and target_label: - edges.append( - ( - PGGraphStorage._decode_graph_label(source_label), - PGGraphStorage._decode_graph_label(target_label), - ) - ) - - return edges - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((PGGraphQueryException,)), +db = PostgreSQLDB( + config={ + "host": "localhost", + "port": 15432, + "user": "rag", + "password": "rag", + "database": "r1", + } +) + + +async def query_with_age(): + await db.initdb() + graph = PGGraphStorage( + namespace="chunk_entity_relation", + global_config={}, + embedding_func=None, ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): - """ - Upsert a node in the AGE database. + graph.db = db + res = await graph.get_node('"A CHRISTMAS CAROL"') + print("Node is: ", res) + res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") + print("Edge is: ", res) + res = await graph.get_node_edges('"SCROOGE"') + print("Node Edges are: ", res) - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - properties = node_data - query = """SELECT * FROM cypher('%s', $$ - MERGE (n:Entity {node_id: "%s"}) - SET n += %s - RETURN n - $$) AS (n agtype)""" % ( - self.graph_name, - label, - PGGraphStorage._format_properties(properties), - ) - - try: - await self._query(query, readonly=False, upsert=True) - logger.debug( - "Upserted node with label '{%s}' and properties: {%s}", - label, - properties, - ) - except Exception as e: - logger.error("Error during upsert: {%s}", e) - raise - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((PGGraphQueryException,)), +async def create_edge_with_age(): + await db.initdb() + graph = PGGraphStorage( + namespace="chunk_entity_relation", + global_config={}, + embedding_func=None, ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): - """ - Upsert an edge and its properties between two nodes identified by their labels. 
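# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# The test calls below pass node ids wrapped in double quotes; PGGraphStorage
# strips the quotes and hex-encodes the remainder, since AGE accepts only
# alphanumeric labels. This round trip mirrors _encode_graph_label and
# _decode_graph_label exactly.
name = '"THE CRATCHITS"'.strip('"')            # -> 'THE CRATCHITS'
encoded = "x" + name.encode().hex()            # alphanumeric-only label for AGE
decoded = bytes.fromhex(encoded.removeprefix("x")).decode()
assert decoded == name
# -------------------------------------------------------------------------------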
+ graph.db = db + await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"}) + await graph.upsert_node('"THE GIRLS"', {"world": "hello"}) + await graph.upsert_edge( + '"THE CRATCHITS"', + '"THE GIRLS"', + edge_data={ + "weight": 7.0, + "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.', + "keywords": '"family, collective effort"', + "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", + }, + ) + res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"') + print("Edge is: ", res) - Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge - """ - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - edge_properties = edge_data - query = """SELECT * FROM cypher('%s', $$ - MATCH (source:Entity {node_id: "%s"}) - WITH source - MATCH (target:Entity {node_id: "%s"}) - MERGE (source)-[r:DIRECTED]->(target) - SET r += %s - RETURN r - $$) AS (r agtype)""" % ( - self.graph_name, - src_label, - tgt_label, - PGGraphStorage._format_properties(edge_properties), - ) - # logger.info(f"-- inserting edge after formatted: {params}") +async def main(): + pool = await get_pool() + sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" + # cypher = "MATCH (n:how_are_you_doing) RETURN n" + async with pool.acquire() as conn: try: - await self._query(query, readonly=False, upsert=True) - logger.debug( - "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - src_label, - tgt_label, - edge_properties, + await conn.execute( + """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""" ) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) - raise + except asyncpg.exceptions.InvalidSchemaNameError: + print("create_graph already exists") + # stmt = await conn.prepare(sql) + row = await conn.fetch(sql) + print("row is: ", row) - async def _node2vec_embed(self): - print("Implemented but never called.") + row = await conn.fetchrow("select '100'::int + 200 as result") + print(row) # -NAMESPACE_TABLE_MAP = { - "full_docs": "LIGHTRAG_DOC_FULL", - "text_chunks": "LIGHTRAG_DOC_CHUNKS", - "chunks": "LIGHTRAG_DOC_CHUNKS", - "entities": "LIGHTRAG_VDB_ENTITY", - "relationships": "LIGHTRAG_VDB_RELATION", - "doc_status": "LIGHTRAG_DOC_STATUS", - "llm_response_cache": "LIGHTRAG_LLM_CACHE", -} - - -TABLES = { - "LIGHTRAG_DOC_FULL": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id VARCHAR(255), - workspace VARCHAR(255), - doc_name VARCHAR(1024), - content TEXT, - meta JSONB, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP, - CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_DOC_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id VARCHAR(255), - workspace VARCHAR(255), - full_doc_id VARCHAR(256), - chunk_order_index INTEGER, - tokens INTEGER, - content TEXT, - content_vector VECTOR, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP, - CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_VDB_ENTITY": { - "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( - id VARCHAR(255), - workspace VARCHAR(255), - entity_name VARCHAR(255), - content TEXT, - content_vector VECTOR, - create_time 
TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP, - CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_VDB_RELATION": { - "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( - id VARCHAR(255), - workspace VARCHAR(255), - source_id VARCHAR(256), - target_id VARCHAR(256), - content TEXT, - content_vector VECTOR, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP, - CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_LLM_CACHE": { - "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - workspace varchar(255) NOT NULL, - id varchar(255) NOT NULL, - mode varchar(32) NOT NULL, - original_prompt TEXT, - return_value TEXT, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP, - CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) - )""" - }, - "LIGHTRAG_DOC_STATUS": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( - workspace varchar(255) NOT NULL, - id varchar(255) NOT NULL, - content_summary varchar(255) NULL, - content_length int4 NULL, - chunks_count int4 NULL, - status varchar(64) NULL, - created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, - updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, - CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) - )""" - }, -} - - -SQL_TEMPLATES = { - # SQL for KVStorage - "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 - """, - "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 - """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 - """, - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 - """, - "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) - """, - "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) - """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) - """, - "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", - "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) - VALUES ($1, $2, $3) - ON CONFLICT (workspace,id) DO UPDATE - SET content = $2, update_time = CURRENT_TIMESTAMP - """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (workspace,mode,id) DO UPDATE - SET original_prompt = EXCLUDED.original_prompt, - return_value=EXCLUDED.return_value, - mode=EXCLUDED.mode, - update_time = CURRENT_TIMESTAMP - """, - "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, content_vector) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (workspace,id) DO UPDATE - SET tokens=EXCLUDED.tokens, - chunk_order_index=EXCLUDED.chunk_order_index, - full_doc_id=EXCLUDED.full_doc_id, - content = EXCLUDED.content, - 
content_vector=EXCLUDED.content_vector, - update_time = CURRENT_TIMESTAMP - """, - "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (workspace,id) DO UPDATE - SET entity_name=EXCLUDED.entity_name, - content=EXCLUDED.content, - content_vector=EXCLUDED.content_vector, - update_time=CURRENT_TIMESTAMP - """, - "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, - target_id, content, content_vector) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (workspace,id) DO UPDATE - SET source_id=EXCLUDED.source_id, - target_id=EXCLUDED.target_id, - content=EXCLUDED.content, - content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP - """, - # SQL for VectorStorage - "entities": """SELECT entity_name FROM - (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_ENTITY where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM - (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_RELATION where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, - "chunks": """SELECT id FROM - (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) - WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - """, -} +if __name__ == "__main__": + asyncio.run(query_with_age()) From 57682389e2cded5a837c6ce700ed28c2c4792298 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:38:32 +0100 Subject: [PATCH 21/33] Update postgres_impl_test.py --- lightrag/kg/postgres_impl_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py index 274f03de..eb6e6e73 100644 --- a/lightrag/kg/postgres_impl_test.py +++ b/lightrag/kg/postgres_impl_test.py @@ -1,8 +1,15 @@ import asyncio -import asyncpg import sys import os +import pipmaster as pm +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + +import asyncpg import psycopg from psycopg_pool import AsyncConnectionPool from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage From b6068046ffc7872ec99f9c8f74a5050ccccb3df8 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:39:39 +0100 Subject: [PATCH 22/33] Update postgres_impl.py --- lightrag/kg/postgres_impl.py | 1261 +++++++++++++++++++++++++++++++--- 1 file changed, 1152 insertions(+), 109 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index eb6e6e73..efeb7cf5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1,25 +1,35 @@ import asyncio -import sys +import inspect +import json import os -import pipmaster as pm +import time +from dataclasses import dataclass +from typing import Union, List, Dict, Set, Any, Tuple +import numpy as np -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") +import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") import asyncpg -import psycopg -from psycopg_pool import AsyncConnectionPool -from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage 
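# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# The tenacity names imported a few lines below are used further down in this
# module to make the AGE upserts retry only on PGGraphQueryException, with the
# same policy the patch configures: at most 3 attempts and exponential backoff
# between 4 and 10 seconds. The exception class and coroutine here are
# stand-ins for illustration only.
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class GraphQueryError(Exception):
    """Stand-in for PGGraphQueryException."""

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((GraphQueryError,)),
)
async def upsert_with_retry():
    ...  # the real upsert_node/upsert_edge issue a Cypher MERGE via PostgreSQLDB
# -------------------------------------------------------------------------------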
+import sys +from tqdm.asyncio import tqdm as tqdm_async +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) -DB = "rag" -USER = "rag" -PASSWORD = "rag" -HOST = "localhost" -PORT = "15432" -os.environ["AGE_GRAPH_NAME"] = "dickens" +from ..utils import logger +from ..base import ( + BaseKVStorage, + BaseVectorStorage, + DocStatusStorage, + DocStatus, + DocProcessingStatus, + BaseGraphStorage, +) if sys.platform.startswith("win"): import asyncio.windows_events @@ -27,110 +37,1143 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -async def get_pool(): - return await asyncpg.create_pool( - f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}", - min_size=10, - max_size=10, - max_queries=5000, - max_inactive_connection_lifetime=300.0, - ) +class PostgreSQLDB: + def __init__(self, config, **kwargs): + self.pool = None + self.host = config.get("host", "localhost") + self.port = config.get("port", 5432) + self.user = config.get("user", "postgres") + self.password = config.get("password", None) + self.database = config.get("database", "postgres") + self.workspace = config.get("workspace", "default") + self.max = 12 + self.increment = 1 + logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier") + if self.user is None or self.password is None or self.database is None: + raise ValueError( + "Missing database user, password, or database in addon_params" + ) -async def main1(): - connection_string = ( - f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" - ) - pool = AsyncConnectionPool(connection_string, open=False) - await pool.open() + async def initdb(self): + try: + self.pool = await asyncpg.create_pool( + user=self.user, + password=self.password, + database=self.database, + host=self.host, + port=self.port, + min_size=1, + max_size=self.max, + ) - try: - conn = await pool.getconn(timeout=10) - async with conn.cursor() as curs: + logger.info( + f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}" + ) + except Exception as e: + logger.error( + f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}" + ) + logger.error(f"PostgreSQL database error: {e}") + raise + + async def check_tables(self): + for k, v in TABLES.items(): try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute("SELECT create_graph('dickens-2')") - await conn.commit() - print("create_graph success") - except ( - psycopg.errors.InvalidSchemaName, - psycopg.errors.UniqueViolation, - ): - print("create_graph already exists") - await conn.rollback() - finally: + await self.query("SELECT 1 FROM {k} LIMIT 1".format(k=k)) + except Exception as e: + logger.error(f"Failed to check table {k} in PostgreSQL database") + logger.error(f"PostgreSQL database error: {e}") + try: + await self.execute(v["ddl"]) + logger.info(f"Created table {k} in PostgreSQL database") + except Exception as e: + logger.error(f"Failed to create table {k} in PostgreSQL database") + logger.error(f"PostgreSQL database error: {e}") + + logger.info("Finished checking all tables in PostgreSQL database") + + async def query( + self, + sql: str, + params: dict = None, + multirows: bool = False, + for_age: bool = False, + graph_name: str = None, + ) -> Union[dict, None, list[dict]]: + async with self.pool.acquire() as connection: + try: + if for_age: + await PostgreSQLDB._prerequisite(connection, graph_name) + if params: + rows 
= await connection.fetch(sql, *params.values()) + else: + rows = await connection.fetch(sql) + + if multirows: + if rows: + columns = [col for col in rows[0].keys()] + data = [dict(zip(columns, row)) for row in rows] + else: + data = [] + else: + if rows: + columns = rows[0].keys() + data = dict(zip(columns, rows[0])) + else: + data = None + return data + except Exception as e: + logger.error(f"PostgreSQL database error: {e}") + print(sql) + print(params) + raise + + async def execute( + self, + sql: str, + data: Union[list, dict] = None, + for_age: bool = False, + graph_name: str = None, + upsert: bool = False, + ): + try: + async with self.pool.acquire() as connection: + if for_age: + await PostgreSQLDB._prerequisite(connection, graph_name) + + if data is None: + await connection.execute(sql) + else: + await connection.execute(sql, *data.values()) + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + ) as e: + if upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") + except Exception as e: + logger.error(f"PostgreSQL database error: {e.__class__} - {e}") + print(sql) + print(data) + raise + + @staticmethod + async def _prerequisite(conn: asyncpg.Connection, graph_name: str): + try: + await conn.execute('SET search_path = ag_catalog, "$user", public') + await conn.execute(f"""select create_graph('{graph_name}')""") + except ( + asyncpg.exceptions.InvalidSchemaNameError, + asyncpg.exceptions.UniqueViolationError, + ): + pass + + +@dataclass +class PGKVStorage(BaseKVStorage): + db: PostgreSQLDB = None + + def __post_init__(self): + self._max_batch_size = self.global_config["embedding_batch_num"] + + ################ QUERY METHODS ################ + + async def get_by_id(self, id: str) -> Union[dict, None]: + """Get doc_full data by id.""" + sql = SQL_TEMPLATES["get_by_id_" + self.namespace] + params = {"workspace": self.db.workspace, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(sql, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + else: + res = await self.db.query(sql, params) + if res: + return res + else: + return None + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, mode: mode, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(sql, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res + else: + return None + + # Query by id + async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]: + """Get doc_chunks data by id""" + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(sql, params, multirows=True) + modes = set() + dict_res: dict[str, dict] = {} + for row in array_res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in array_res: + dict_res[row["mode"]][row["id"]] = row + res = [{k: v} for k, v in dict_res.items()] + else: + res = await self.db.query(sql, params, multirows=True) + if res: + return res + else: + return None + + async def all_keys(self) -> list[dict]: + if "llm_response_cache" == 
self.namespace: + sql = "select workspace,mode,id from lightrag_llm_cache" + res = await self.db.query(sql, multirows=True) + return res + else: + logger.error( + f"all_keys is only implemented for llm_response_cache, not for {self.namespace}" + ) + + async def filter_keys(self, keys: List[str]) -> Set[str]: + """Filter out duplicated content""" + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=NAMESPACE_TABLE_MAP[self.namespace], + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.db.workspace} + try: + res = await self.db.query(sql, params, multirows=True) + if res: + exist_keys = [key["id"] for key in res] + else: + exist_keys = [] + data = set([s for s in keys if s not in exist_keys]) + return data + except Exception as e: + logger.error(f"PostgreSQL database error: {e}") + print(sql) + print(params) + + ################ INSERT METHODS ################ + async def upsert(self, data: Dict[str, dict]): + if self.namespace == "text_chunks": + pass + elif self.namespace == "full_docs": + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_doc_full"] + _data = { + "id": k, + "content": v["content"], + "workspace": self.db.workspace, + } + await self.db.execute(upsert_sql, _data) + elif self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "mode": mode, + } + + await self.db.execute(upsert_sql, _data) + + async def index_done_callback(self): + if self.namespace in ["full_docs", "text_chunks"]: + logger.info("full doc and chunk data had been saved into postgresql db!") + + +@dataclass +class PGVectorStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + db: PostgreSQLDB = None + + def __post_init__(self): + self._max_batch_size = self.global_config["embedding_batch_num"] + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + def _upsert_chunks(self, item: dict): + try: + upsert_sql = SQL_TEMPLATES["upsert_chunk"] + data = { + "workspace": self.db.workspace, + "id": item["__id__"], + "tokens": item["tokens"], + "chunk_order_index": item["chunk_order_index"], + "full_doc_id": item["full_doc_id"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + } + except Exception as e: + logger.error(f"Error to prepare upsert sql: {e}") + print(item) + raise e + return upsert_sql, data + + def _upsert_entities(self, item: dict): + upsert_sql = SQL_TEMPLATES["upsert_entity"] + data = { + "workspace": self.db.workspace, + "id": item["__id__"], + "entity_name": item["entity_name"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + } + return upsert_sql, data + + def _upsert_relationships(self, item: dict): + upsert_sql = SQL_TEMPLATES["upsert_relationship"] + data = { + "workspace": self.db.workspace, + "id": item["__id__"], + "source_id": item["src_id"], + "target_id": item["tgt_id"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + } + return upsert_sql, data + + async def upsert(self, data: Dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + current_time = time.time() + 
list_data = [ + { + "__id__": k, + "__created_at__": current_time, + **{k1: v1 for k1, v1 in v.items()}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async( + total=len(embedding_tasks), desc="Generating embeddings", unit="batch" + ) + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + for item in list_data: + if self.namespace == "chunks": + upsert_sql, data = self._upsert_chunks(item) + elif self.namespace == "entities": + upsert_sql, data = self._upsert_entities(item) + elif self.namespace == "relationships": + upsert_sql, data = self._upsert_relationships(item) + else: + raise ValueError(f"{self.namespace} is not supported") + + await self.db.execute(upsert_sql, data) + + async def index_done_callback(self): + logger.info("vector data had been saved into postgresql db!") + + #################### query method ############### + async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + """从向量数据库中查询数据""" + embeddings = await self.embedding_func([query]) + embedding = embeddings[0] + embedding_string = ",".join(map(str, embedding)) + + sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + params = { + "workspace": self.db.workspace, + "better_than_threshold": self.cosine_better_than_threshold, + "top_k": top_k, + } + results = await self.db.query(sql, params=params, multirows=True) + return results + + +@dataclass +class PGDocStatusStorage(DocStatusStorage): + """PostgreSQL implementation of document status storage""" + + db: PostgreSQLDB = None + + def __post_init__(self): pass + async def filter_keys(self, data: list[str]) -> set[str]: + """Return keys that don't exist in storage""" + keys = ",".join([f"'{_id}'" for _id in data]) + sql = ( + f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" + ) + result = await self.db.query(sql, {"workspace": self.db.workspace}, True) + # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. + if result is None: + return set(data) + else: + existed = set([element["id"] for element in result]) + return set(data) - existed -db = PostgreSQLDB( - config={ - "host": "localhost", - "port": 15432, - "user": "rag", - "password": "rag", - "database": "r1", - } -) + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + sql = """SELECT status as "status", COUNT(1) as "count" + FROM LIGHTRAG_DOC_STATUS + where workspace=$1 GROUP BY STATUS + """ + result = await self.db.query(sql, {"workspace": self.db.workspace}, True) + # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] 
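# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# PostgreSQLDB.query() above expands the params dict positionally, i.e. it calls
# connection.fetch(sql, *params.values()), so the dict's insertion order has to
# line up with the $1, $2, ... placeholders. Run against a raw asyncpg pool, the
# same status count looks like this (pool and workspace are assumed to exist):
import asyncpg

async def count_by_status(pool: asyncpg.Pool, workspace: str) -> dict:
    sql = ('SELECT status, COUNT(1) AS count FROM LIGHTRAG_DOC_STATUS '
           'WHERE workspace=$1 GROUP BY status')
    params = {"workspace": workspace}          # one entry per placeholder, in order
    async with pool.acquire() as conn:
        rows = await conn.fetch(sql, *params.values())
    return {row["status"]: row["count"] for row in rows}
# -------------------------------------------------------------------------------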
+ counts = {} + for doc in result: + counts[doc["status"]] = doc["count"] + return counts - -async def query_with_age(): - await db.initdb() - graph = PGGraphStorage( - namespace="chunk_entity_relation", - global_config={}, - embedding_func=None, - ) - graph.db = db - res = await graph.get_node('"A CHRISTMAS CAROL"') - print("Node is: ", res) - res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") - print("Edge is: ", res) - res = await graph.get_node_edges('"SCROOGE"') - print("Node Edges are: ", res) - - -async def create_edge_with_age(): - await db.initdb() - graph = PGGraphStorage( - namespace="chunk_entity_relation", - global_config={}, - embedding_func=None, - ) - graph.db = db - await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"}) - await graph.upsert_node('"THE GIRLS"', {"world": "hello"}) - await graph.upsert_edge( - '"THE CRATCHITS"', - '"THE GIRLS"', - edge_data={ - "weight": 7.0, - "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.', - "keywords": '"family, collective effort"', - "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", - }, - ) - res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"') - print("Edge is: ", res) - - -async def main(): - pool = await get_pool() - sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" - # cypher = "MATCH (n:how_are_you_doing) RETURN n" - async with pool.acquire() as conn: - try: - await conn.execute( - """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""" + async def get_docs_by_status( + self, status: DocStatus + ) -> Dict[str, DocProcessingStatus]: + """Get all documents by status""" + sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1" + params = {"workspace": self.db.workspace, "status": status} + result = await self.db.query(sql, params, True) + # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] 
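# Editor's note (illustrative, not part of the patch): asyncpg decodes timestamp
# columns to datetime.datetime by default, so the created_at / updated_at values
# consumed below arrive as datetime objects rather than the string literals shown
# in the example comment above.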
+ # Converting to be a dict + return { + element["id"]: DocProcessingStatus( + content_summary=element["content_summary"], + content_length=element["content_length"], + status=element["status"], + created_at=element["created_at"], + updated_at=element["updated_at"], + chunks_count=element["chunks_count"], ) - except asyncpg.exceptions.InvalidSchemaNameError: - print("create_graph already exists") - # stmt = await conn.prepare(sql) - row = await conn.fetch(sql) - print("row is: ", row) + for element in result + } - row = await conn.fetchrow("select '100'::int + 200 as result") - print(row) # + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return await self.get_docs_by_status(DocStatus.FAILED) + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return await self.get_docs_by_status(DocStatus.PENDING) + + async def index_done_callback(self): + """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" + logger.info("Doc status had been saved into postgresql db!") + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: Dictionary of document IDs and their status data + """ + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status) + values($1,$2,$3,$4,$5,$6) + on conflict(id,workspace) do update set + content_summary = EXCLUDED.content_summary, + content_length = EXCLUDED.content_length, + chunks_count = EXCLUDED.chunks_count, + status = EXCLUDED.status, + updated_at = CURRENT_TIMESTAMP""" + for k, v in data.items(): + # chunks_count is optional + await self.db.execute( + sql, + { + "workspace": self.db.workspace, + "id": k, + "content_summary": v["content_summary"], + "content_length": v["content_length"], + "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, + "status": v["status"], + }, + ) + return data -if __name__ == "__main__": - asyncio.run(query_with_age()) +class PGGraphQueryException(Exception): + """Exception for the AGE queries.""" + + def __init__(self, exception: Union[str, Dict]) -> None: + if isinstance(exception, dict): + self.message = exception["message"] if "message" in exception else "unknown" + self.details = exception["details"] if "details" in exception else "unknown" + else: + self.message = exception + self.details = "unknown" + + def get_message(self) -> str: + return self.message + + def get_details(self) -> Any: + return self.details + + +@dataclass +class PGGraphStorage(BaseGraphStorage): + db: PostgreSQLDB = None + + @staticmethod + def load_nx_graph(file_name): + print("no preloading of graph with AGE in production") + + def __init__(self, namespace, global_config, embedding_func): + super().__init__( + namespace=namespace, + global_config=global_config, + embedding_func=embedding_func, + ) + self.graph_name = os.environ["AGE_GRAPH_NAME"] + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def index_done_callback(self): + print("KG successfully indexed.") + + @staticmethod + def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: + """ + Convert a record returned from an age query to a dictionary + + Args: + record (): a record from an age query result + + Returns: + Dict[str, Any]: a dictionary representation of the record where + the dictionary key is the field name and the value is the + value converted to a python type + """ + # 
result holder + d = {} + + # prebuild a mapping of vertex_id to vertex mappings to be used + # later to build edges + vertices = {} + for k in record.keys(): + v = record[k] + # agtype comes back '{key: value}::type' which must be parsed + if isinstance(v, str) and "::" in v: + dtype = v.split("::")[-1] + v = v.split("::")[0] + if dtype == "vertex": + vertex = json.loads(v) + vertices[vertex["id"]] = vertex.get("properties") + + # iterate returned fields and parse appropriately + for k in record.keys(): + v = record[k] + if isinstance(v, str) and "::" in v: + dtype = v.split("::")[-1] + v = v.split("::")[0] + else: + dtype = "" + + if dtype == "vertex": + vertex = json.loads(v) + field = vertex.get("properties") + if not field: + field = {} + field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) + d[k] = field + # convert edge from id-label->id by replacing id with node information + # we only do this if the vertex was also returned in the query + # this is an attempt to be consistent with neo4j implementation + elif dtype == "edge": + edge = json.loads(v) + d[k] = ( + vertices.get(edge["start_id"], {}), + edge[ + "label" + ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" + vertices.get(edge["end_id"], {}), + ) + else: + d[k] = json.loads(v) if isinstance(v, str) else v + + return d + + @staticmethod + def _format_properties( + properties: Dict[str, Any], _id: Union[str, None] = None + ) -> str: + """ + Convert a dictionary of properties to a string representation that + can be used in a cypher query insert/merge statement. + + Args: + properties (Dict[str,str]): a dictionary containing node/edge properties + _id (Union[str, None]): the id of the node or None if none exists + + Returns: + str: the properties dictionary as a properly formatted string + """ + props = [] + # wrap property key in backticks to escape + for k, v in properties.items(): + prop = f"`{k}`: {json.dumps(v)}" + props.append(prop) + if _id is not None and "id" not in properties: + props.append( + f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" + ) + return "{" + ", ".join(props) + "}" + + @staticmethod + def _encode_graph_label(label: str) -> str: + """ + Since AGE supports only alphanumerical labels, we will encode generic label as HEX string + + Args: + label (str): the original label + + Returns: + str: the encoded label + """ + return "x" + label.encode().hex() + + @staticmethod + def _decode_graph_label(encoded_label: str) -> str: + """ + Since AGE supports only alphanumerical labels, we will encode generic label as HEX string + + Args: + encoded_label (str): the encoded label + + Returns: + str: the decoded label + """ + return bytes.fromhex(encoded_label.removeprefix("x")).decode() + + @staticmethod + def _get_col_name(field: str, idx: int) -> str: + """ + Convert a cypher return field to a pgsql select field + If possible keep the cypher column name, but create a generic name if necessary + + Args: + field (str): a return field from a cypher query to be formatted for pgsql + idx (int): the position of the field in the return statement + + Returns: + str: the field to be used in the pgsql select statement + """ + # remove white space + field = field.strip() + # if an alias is provided for the field, use it + if " as " in field: + return field.split(" as ")[-1].strip() + # if the return value is an unnamed primitive, give it a generic name + if field.isnumeric() or field in ("true", "false", "null"): + return f"column_{idx}" + # otherwise return the 
value stripping out some common special chars + return field.replace("(", "_").replace(")", "") + + async def _query( + self, query: str, readonly: bool = True, upsert: bool = False + ) -> List[Dict[str, Any]]: + """ + Query the graph by taking a cypher query, converting it to an + age compatible query, executing it and converting the result + + Args: + query (str): a cypher query to be executed + params (dict): parameters for the query + + Returns: + List[Dict[str, Any]]: a list of dictionaries containing the result set + """ + # convert cypher query to pgsql/age query + wrapped_query = query + + # execute the query, rolling back on an error + try: + if readonly: + data = await self.db.query( + wrapped_query, + multirows=True, + for_age=True, + graph_name=self.graph_name, + ) + else: + data = await self.db.execute( + wrapped_query, + for_age=True, + graph_name=self.graph_name, + upsert=upsert, + ) + except Exception as e: + raise PGGraphQueryException( + { + "message": f"Error executing graph query: {query}", + "wrapped": wrapped_query, + "detail": str(e), + } + ) from e + + if data is None: + result = [] + # decode records + else: + result = [PGGraphStorage._record_to_dict(d) for d in data] + + return result + + async def has_node(self, node_id: str) -> bool: + entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN count(n) > 0 AS node_exists + $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) + + single_result = (await self._query(query))[0] + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + single_result["node_exists"], + ) + + return single_result["node_exists"] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) + RETURN COUNT(r) > 0 AS edge_exists + $$) AS (edge_exists bool)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + + single_result = (await self._query(query))[0] + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + single_result["edge_exists"], + ) + return single_result["edge_exists"] + + async def get_node(self, node_id: str) -> Union[dict, None]: + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN n + $$) AS (n agtype)""" % (self.graph_name, label) + record = await self._query(query) + if record: + node = record[0] + node_dict = node["n"] + logger.debug( + "{%s}: query: {%s}, result: {%s}", + inspect.currentframe().f_code.co_name, + query, + node_dict, + ) + return node_dict + return None + + async def node_degree(self, node_id: str) -> int: + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"})-[]->(x) + RETURN count(x) AS total_edge_count + $$) AS (total_edge_count integer)""" % (self.graph_name, label) + record = (await self._query(query))[0] + if record: + edge_count = int(record["total_edge_count"]) + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + edge_count, + ) + return edge_count + + async def edge_degree(self, 
src_id: str, tgt_id: str) -> int: + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + logger.debug( + "{%s}:query:src_Degree+trg_degree:result:{%s}", + inspect.currentframe().f_code.co_name, + degrees, + ) + return degrees + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + """ + Find all edges between nodes of two given labels + + Args: + source_node_id (str): Label of the source nodes + target_node_id (str): Label of the target nodes + + Returns: + list: List of all relationships/edges found + """ + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) + RETURN properties(r) as edge_properties + LIMIT 1 + $$) AS (edge_properties agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + record = await self._query(query) + if record and record[0] and record[0]["edge_properties"]: + result = record[0]["edge_properties"] + logger.debug( + "{%s}:query:{%s}:result:{%s}", + inspect.currentframe().f_code.co_name, + query, + result, + ) + return result + + async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + """ + Retrieves all edges (relationships) for a particular node identified by its label. + :return: List of dictionaries containing edge information + """ + label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + OPTIONAL MATCH (n)-[r]-(connected) + RETURN n, r, connected + $$) AS (n agtype, r agtype, connected agtype)""" % ( + self.graph_name, + label, + ) + + results = await self._query(query) + edges = [] + for record in results: + source_node = record["n"] if record["n"] else None + connected_node = record["connected"] if record["connected"] else None + + source_label = ( + source_node["node_id"] + if source_node and source_node["node_id"] + else None + ) + target_label = ( + connected_node["node_id"] + if connected_node and connected_node["node_id"] + else None + ) + + if source_label and target_label: + edges.append( + ( + PGGraphStorage._decode_graph_label(source_label), + PGGraphStorage._decode_graph_label(target_label), + ) + ) + + return edges + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((PGGraphQueryException,)), + ) + async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + """ + Upsert a node in the AGE database. 
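The upsert methods above are guarded by a tenacity retry policy; a minimal, self-contained sketch of that pattern (with a stand-in exception class) looks like this:

```python
# Minimal sketch of the retry policy applied to upsert_node()/upsert_edge() above.
# PGGraphQueryException is a stand-in for the storage module's own exception type.
import asyncio
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class PGGraphQueryException(Exception):
    pass

@retry(
    stop=stop_after_attempt(3),                          # give up after three attempts
    wait=wait_exponential(multiplier=1, min=4, max=10),  # exponential backoff, clamped to 4-10 s
    retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def flaky_upsert() -> str:
    # a real implementation would issue the MERGE ... SET query shown below
    return "ok"

print(asyncio.run(flaky_upsert()))
```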
+ + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + properties = node_data + + query = """SELECT * FROM cypher('%s', $$ + MERGE (n:Entity {node_id: "%s"}) + SET n += %s + RETURN n + $$) AS (n agtype)""" % ( + self.graph_name, + label, + PGGraphStorage._format_properties(properties), + ) + + try: + await self._query(query, readonly=False, upsert=True) + logger.debug( + "Upserted node with label '{%s}' and properties: {%s}", + label, + properties, + ) + except Exception as e: + logger.error("Error during upsert: {%s}", e) + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((PGGraphQueryException,)), + ) + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] + ): + """ + Upsert an edge and its properties between two nodes identified by their labels. + + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge + """ + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + edge_properties = edge_data + + query = """SELECT * FROM cypher('%s', $$ + MATCH (source:Entity {node_id: "%s"}) + WITH source + MATCH (target:Entity {node_id: "%s"}) + MERGE (source)-[r:DIRECTED]->(target) + SET r += %s + RETURN r + $$) AS (r agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + PGGraphStorage._format_properties(edge_properties), + ) + # logger.info(f"-- inserting edge after formatted: {params}") + try: + await self._query(query, readonly=False, upsert=True) + logger.debug( + "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", + src_label, + tgt_label, + edge_properties, + ) + except Exception as e: + logger.error("Error during edge upsert: {%s}", e) + raise + + async def _node2vec_embed(self): + print("Implemented but never called.") + + +NAMESPACE_TABLE_MAP = { + "full_docs": "LIGHTRAG_DOC_FULL", + "text_chunks": "LIGHTRAG_DOC_CHUNKS", + "chunks": "LIGHTRAG_DOC_CHUNKS", + "entities": "LIGHTRAG_VDB_ENTITY", + "relationships": "LIGHTRAG_VDB_RELATION", + "doc_status": "LIGHTRAG_DOC_STATUS", + "llm_response_cache": "LIGHTRAG_LLM_CACHE", +} + + +TABLES = { + "LIGHTRAG_DOC_FULL": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( + id VARCHAR(255), + workspace VARCHAR(255), + doc_name VARCHAR(1024), + content TEXT, + meta JSONB, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_DOC_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( + id VARCHAR(255), + workspace VARCHAR(255), + full_doc_id VARCHAR(256), + chunk_order_index INTEGER, + tokens INTEGER, + content TEXT, + content_vector VECTOR, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_VDB_ENTITY": { + "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( + id VARCHAR(255), + workspace VARCHAR(255), + entity_name VARCHAR(255), + content TEXT, + content_vector VECTOR, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) + )""" + }, + 
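A small, illustrative expansion of what `_format_properties` feeds into the `SET n += ...` clause of the upsert queries above:

```python
# Illustrative only: the map literal _format_properties() produces for a small
# property dict, i.e. what gets spliced into "SET n += ..." by the upserts above.
import json
from typing import Any, Dict

def format_properties(properties: Dict[str, Any]) -> str:
    props = [f"`{k}`: {json.dumps(v)}" for k, v in properties.items()]
    return "{" + ", ".join(props) + "}"

node_data = {"node_id": "x416461", "entity_type": "person", "description": 'says "hi"'}
print(format_properties(node_data))
# {`node_id`: "x416461", `entity_type`: "person", `description`: "says \"hi\""}
```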
"LIGHTRAG_VDB_RELATION": { + "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( + id VARCHAR(255), + workspace VARCHAR(255), + source_id VARCHAR(256), + target_id VARCHAR(256), + content TEXT, + content_vector VECTOR, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_LLM_CACHE": { + "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( + workspace varchar(255) NOT NULL, + id varchar(255) NOT NULL, + mode varchar(32) NOT NULL, + original_prompt TEXT, + return_value TEXT, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) + )""" + }, + "LIGHTRAG_DOC_STATUS": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( + workspace varchar(255) NOT NULL, + id varchar(255) NOT NULL, + content_summary varchar(255) NULL, + content_length int4 NULL, + chunks_count int4 NULL, + status varchar(64) NULL, + created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, + updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, + CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) + )""" + }, +} + + +SQL_TEMPLATES = { + # SQL for KVStorage + "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content + FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 + """, + "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, + chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 + """, + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 + """, + "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content + FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) + """, + "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, + chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) + """, + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) + """, + "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", + "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) + VALUES ($1, $2, $3) + ON CONFLICT (workspace,id) DO UPDATE + SET content = $2, update_time = CURRENT_TIMESTAMP + """, + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (workspace,mode,id) DO UPDATE + SET original_prompt = EXCLUDED.original_prompt, + return_value=EXCLUDED.return_value, + mode=EXCLUDED.mode, + update_time = CURRENT_TIMESTAMP + """, + "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, + chunk_order_index, full_doc_id, content, content_vector) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (workspace,id) DO UPDATE + SET tokens=EXCLUDED.tokens, + chunk_order_index=EXCLUDED.chunk_order_index, + full_doc_id=EXCLUDED.full_doc_id, + content = EXCLUDED.content, + content_vector=EXCLUDED.content_vector, + update_time = CURRENT_TIMESTAMP + """, + "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, 
entity_name, content, content_vector) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (workspace,id) DO UPDATE + SET entity_name=EXCLUDED.entity_name, + content=EXCLUDED.content, + content_vector=EXCLUDED.content_vector, + update_time=CURRENT_TIMESTAMP + """, + "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, + target_id, content, content_vector) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,id) DO UPDATE + SET source_id=EXCLUDED.source_id, + target_id=EXCLUDED.target_id, + content=EXCLUDED.content, + content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP + """, + # SQL for VectorStorage + "entities": """SELECT entity_name FROM + (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_ENTITY where workspace=$1) + WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + """, + "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM + (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_VDB_RELATION where workspace=$1) + WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + """, + "chunks": """SELECT id FROM + (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) + WHERE distance>$2 ORDER BY distance DESC LIMIT $3 + """, +} From b2c1144219d99ad0de0114d56993626b060c5cd9 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:39:58 +0100 Subject: [PATCH 23/33] Update redis_impl.py --- lightrag/kg/redis_impl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index a126074d..013196e3 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,6 +1,9 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass +import pipmaster as pm +if not pm.is_installed("redis"): + pm.install("redis") # aioredis is a depricated library, replaced with redis from redis.asyncio import Redis From 52037205ebd8dc24485480d32e794ae556f5c7e0 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:40:50 +0100 Subject: [PATCH 24/33] Update tidb_impl.py --- lightrag/kg/tidb_impl.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 2cf698e1..8ba1de65 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -4,13 +4,18 @@ from dataclasses import dataclass from typing import Union import numpy as np +import pipmaster as pm +if not pm.is_installed("pymysql"): + pm.install("pymysql") +if not pm.is_installed("sqlalchemy"): + pm.install("sqlalchemy") + from sqlalchemy import create_engine, text from tqdm import tqdm from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage from lightrag.utils import logger - class TiDB(object): def __init__(self, config, **kwargs): self.host = config.get("host", None) From 56e9c9f4d5e8c4672d472e8393226376bd612729 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 09:59:26 +0100 Subject: [PATCH 25/33] Moved the storages to kg folder --- .../{storage/json_kv_storage.py => kg/json_kv_impl.py} | 0 .../jsondocstatus_storage.py => kg/jsondocstatus_impl.py} | 0 .../nano_vector_db.py => kg/nano_vector_db_impl.py} | 0 .../{storage/networkx_storage.py => kg/networkx_impl.py} | 0 lightrag/lightrag.py | 8 ++++---- lightrag/storage/__init__.py | 1 - 6 files changed, 4 
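In the vector-search templates above, `<=>` is pgvector's cosine-distance operator, so the subquery stores cosine similarity (`1 - distance`) under the alias `distance`, keeps rows above the `$2` threshold, and returns the `$3` most similar. A hedged sketch of the values a caller would bind:

```python
# Illustrative parameters for the "entities" similarity template above; the
# workspace name and thresholds are example values, not taken from the patch.
embedding = [0.12, -0.03, 0.88]                        # toy query embedding
embedding_string = ",".join(str(x) for x in embedding)

sql = """SELECT entity_name FROM
           (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
         WHERE distance>$2 ORDER BY distance DESC LIMIT $3""".format(embedding_string=embedding_string)

params = ["default_workspace", 0.2, 10]   # $1 workspace, $2 similarity cut-off, $3 top-k
print(sql)
```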
insertions(+), 5 deletions(-) rename lightrag/{storage/json_kv_storage.py => kg/json_kv_impl.py} (100%) rename lightrag/{storage/jsondocstatus_storage.py => kg/jsondocstatus_impl.py} (100%) rename lightrag/{storage/nano_vector_db.py => kg/nano_vector_db_impl.py} (100%) rename lightrag/{storage/networkx_storage.py => kg/networkx_impl.py} (100%) delete mode 100644 lightrag/storage/__init__.py diff --git a/lightrag/storage/json_kv_storage.py b/lightrag/kg/json_kv_impl.py similarity index 100% rename from lightrag/storage/json_kv_storage.py rename to lightrag/kg/json_kv_impl.py diff --git a/lightrag/storage/jsondocstatus_storage.py b/lightrag/kg/jsondocstatus_impl.py similarity index 100% rename from lightrag/storage/jsondocstatus_storage.py rename to lightrag/kg/jsondocstatus_impl.py diff --git a/lightrag/storage/nano_vector_db.py b/lightrag/kg/nano_vector_db_impl.py similarity index 100% rename from lightrag/storage/nano_vector_db.py rename to lightrag/kg/nano_vector_db_impl.py diff --git a/lightrag/storage/networkx_storage.py b/lightrag/kg/networkx_impl.py similarity index 100% rename from lightrag/storage/networkx_storage.py rename to lightrag/kg/networkx_impl.py diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9a849921..b40eecaa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -38,10 +38,10 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP STORAGES = { - "NetworkXStorage": ".storage.networkx_storage", - "JsonKVStorage": ".storage.json_kv_storage", - "NanoVectorDBStorage": ".storage.nano_vector_db", - "JsonDocStatusStorage": ".storage.jsondocstatus_storage", + "NetworkXStorage": ".kg.networkx_impl", + "JsonKVStorage": ".kg.json_kv_impl", + "NanoVectorDBStorage": ".kg.nano_vector_db_impl", + "JsonDocStatusStorage": ".kg.jsondocstatus_impl", "Neo4JStorage": ".kg.neo4j_impl", "OracleKVStorage": ".kg.oracle_impl", "OracleGraphStorage": ".kg.oracle_impl", diff --git a/lightrag/storage/__init__.py b/lightrag/storage/__init__.py deleted file mode 100644 index 8b137891..00000000 --- a/lightrag/storage/__init__.py +++ /dev/null @@ -1 +0,0 @@ - From 5d97d7e42c9dc3a302f3a08b31d6fd334d9cace0 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 10:01:59 +0100 Subject: [PATCH 26/33] removed storage.py --- lightrag/storage.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 lightrag/storage.py diff --git a/lightrag/storage.py b/lightrag/storage.py deleted file mode 100644 index 91ba7bcc..00000000 --- a/lightrag/storage.py +++ /dev/null @@ -1 +0,0 @@ -# This file is not needed anymore (TODO: remove) From 315f0bf5f9cab44fd8b023927968c9866923f164 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 10:13:06 +0100 Subject: [PATCH 27/33] Added escaping to list_of_list_to_csv --- lightrag/utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 3454ea7c..9550f688 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -16,7 +16,9 @@ import numpy as np import tiktoken from lightrag.prompt import PROMPTS - +from typing import List +import csv +import io class UnlimitedSemaphore: """A context manager that allows unlimited access.""" @@ -235,9 +237,17 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: return list_data + + def list_of_list_to_csv(data: List[List[str]]) -> str: output = io.StringIO() - writer = csv.writer(output) + writer = csv.writer( + output, + quoting=csv.QUOTE_ALL, # Quote all fields + escapechar='\\', # 
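The updated STORAGES mapping above pairs class names with relative module paths under the new `kg` package; a hedged sketch of how such a mapping is typically resolved (the actual loader in lightrag.py is not shown in this hunk):

```python
# Sketch only: the real loader in lightrag.py is not part of this hunk.
import importlib

STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.jsondocstatus_impl",
}

def load_storage_class(name: str):
    module = importlib.import_module(STORAGES[name], package="lightrag")
    return getattr(module, name)

JsonKVStorage = load_storage_class("JsonKVStorage")   # lightrag.kg.json_kv_impl.JsonKVStorage
```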
Use backslash as escape character + quotechar='"', # Use double quotes + lineterminator='\n' # Explicit line terminator + ) writer.writerows(data) return output.getvalue() From 16d1ae77ee02d0b0d8942aa27367ce8c2c998ae6 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 10:15:30 +0100 Subject: [PATCH 28/33] fixed csv_string_to_list when data contains null --- lightrag/utils.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 9550f688..9792e251 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -253,9 +253,23 @@ def list_of_list_to_csv(data: List[List[str]]) -> str: def csv_string_to_list(csv_string: str) -> List[List[str]]: - output = io.StringIO(csv_string) - reader = csv.reader(output) - return [row for row in reader] + # Clean the string by removing NUL characters + cleaned_string = csv_string.replace('\0', '') + + output = io.StringIO(cleaned_string) + reader = csv.reader( + output, + quoting=csv.QUOTE_ALL, # Match the writer configuration + escapechar='\\', # Use backslash as escape character + quotechar='"', # Use double quotes + ) + + try: + return [row for row in reader] + except csv.Error as e: + raise ValueError(f"Failed to parse CSV string: {str(e)}") + finally: + output.close() def save_data_to_file(data, file_name): From 7957c7c7537d1adeb505010b1e50fbb5b1d3d21f Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 10:23:50 +0100 Subject: [PATCH 29/33] split the html and javascript code --- lightrag/api/static/index.html | 353 +----------------- lightrag/api/{webui => }/static/js/graph.js | 0 lightrag/api/static/js/lightrag_api.js | 353 ++++++++++++++++++ .../static/__init__.py | 0 .../static/css/__init__.py | 0 .../static/css/graph.css | 0 .../static/css/lightrag.css | 0 .../static/index.html | 0 .../static/js/__init__.py | 0 .../static/js/lightrag.js | 0 10 files changed, 354 insertions(+), 352 deletions(-) rename lightrag/api/{webui => }/static/js/graph.js (100%) create mode 100644 lightrag/api/static/js/lightrag_api.js rename lightrag/api/{webui => webui_depricated}/static/__init__.py (100%) rename lightrag/api/{webui => webui_depricated}/static/css/__init__.py (100%) rename lightrag/api/{webui => webui_depricated}/static/css/graph.css (100%) rename lightrag/api/{webui => webui_depricated}/static/css/lightrag.css (100%) rename lightrag/api/{webui => webui_depricated}/static/index.html (100%) rename lightrag/api/{webui => webui_depricated}/static/js/__init__.py (100%) rename lightrag/api/{webui => webui_depricated}/static/js/lightrag.js (100%) diff --git a/lightrag/api/static/index.html b/lightrag/api/static/index.html index 56a70ad7..60900c03 100644 --- a/lightrag/api/static/index.html +++ b/lightrag/api/static/index.html @@ -98,358 +98,7 @@ - - // Utility functions - const showToast = (message, duration = 3000) => { - const toast = document.getElementById('toast'); - toast.querySelector('div').textContent = message; - toast.classList.remove('hidden'); - setTimeout(() => toast.classList.add('hidden'), duration); - }; - - const fetchWithAuth = async (url, options = {}) => { - const headers = { - ...(options.headers || {}), - ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {}) - }; - return fetch(url, { ...options, headers }); - }; - - // Page renderers - const pages = { - 'file-manager': () => ` -
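A short round-trip check of the CSV helpers patched above: `csv_string_to_list` deliberately mirrors the writer's quoting settings, so quoted commas, embedded quotes, and newlines survive the trip.

```python
# Round-trip sanity check for the patched CSV helpers; the import path follows
# the file being modified (lightrag/utils.py).
from lightrag.utils import csv_string_to_list, list_of_list_to_csv

rows = [
    ["entity", 'description with "quotes", commas\nand a newline'],
    ["relation", "plain value"],
]

csv_text = list_of_list_to_csv(rows)
assert csv_string_to_list(csv_text) == rows   # matching QUOTE_ALL / escapechar settings
```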
-

File Manager

- -
- - -
- -
-

Selected Files

-
-
- - - - -
-

Indexed Files

-
-
- -
- `, - - 'query': () => ` -
-

Query Database

- -
-
- - -
- -
- - -
- - - -
-
-
- `, - - 'knowledge-graph': () => ` -
-
- - - -

Under Construction

-

Knowledge graph visualization will be available in a future update.

-
-
- `, - - 'status': () => ` -
-

System Status

-
-
-

System Health

-
-
-
-

Configuration

-
-
-
-
- `, - - 'settings': () => ` -
-

Settings

- -
-
-
- - -
- - -
-
-
- ` - }; - - // Page handlers - const handlers = { - 'file-manager': () => { - const fileInput = document.getElementById('fileInput'); - const dropZone = fileInput.parentElement.parentElement; - const fileList = document.querySelector('#fileList div'); - const indexedFiles = document.querySelector('#indexedFiles div'); - const uploadBtn = document.getElementById('uploadBtn'); - - const updateFileList = () => { - fileList.innerHTML = state.files.map(file => ` -
- ${file.name} - -
- `).join(''); - }; - - const updateIndexedFiles = async () => { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - indexedFiles.innerHTML = data.indexed_files.map(file => ` -
- ${file} -
- `).join(''); - }; - - dropZone.addEventListener('dragover', (e) => { - e.preventDefault(); - dropZone.classList.add('border-blue-500'); - }); - - dropZone.addEventListener('dragleave', () => { - dropZone.classList.remove('border-blue-500'); - }); - - dropZone.addEventListener('drop', (e) => { - e.preventDefault(); - dropZone.classList.remove('border-blue-500'); - const files = Array.from(e.dataTransfer.files); - state.files.push(...files); - updateFileList(); - }); - - fileInput.addEventListener('change', () => { - state.files.push(...Array.from(fileInput.files)); - updateFileList(); - }); - - uploadBtn.addEventListener('click', async () => { - if (state.files.length === 0) { - showToast('Please select files to upload'); - return; - } - let apiKey = localStorage.getItem('apiKey') || ''; - const progress = document.getElementById('uploadProgress'); - const progressBar = progress.querySelector('div'); - const statusText = document.getElementById('uploadStatus'); - progress.classList.remove('hidden'); - - for (let i = 0; i < state.files.length; i++) { - const formData = new FormData(); - formData.append('file', state.files[i]); - - try { - await fetch('/documents/upload', { - method: 'POST', - headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, - body: formData - }); - - const percentage = ((i + 1) / state.files.length) * 100; - progressBar.style.width = `${percentage}%`; - statusText.textContent = i + 1; - } catch (error) { - console.error('Upload error:', error); - } - } - progress.classList.add('hidden'); - }); - - updateIndexedFiles(); - }, - - 'query': () => { - const queryBtn = document.getElementById('queryBtn'); - const queryInput = document.getElementById('queryInput'); - const queryMode = document.getElementById('queryMode'); - const queryResult = document.getElementById('queryResult'); - - queryBtn.addEventListener('click', async () => { - const query = queryInput.value.trim(); - if (!query) { - showToast('Please enter a query'); - return; - } - - queryBtn.disabled = true; - queryBtn.innerHTML = ` - - - - - Processing... - `; - - try { - const response = await fetchWithAuth('/query', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - query, - mode: queryMode.value, - stream: false, - only_need_context: false - }) - }); - - const data = await response.json(); - queryResult.innerHTML = marked.parse(data.response); - } catch (error) { - showToast('Error processing query'); - } finally { - queryBtn.disabled = false; - queryBtn.textContent = 'Send Query'; - } - }); - }, - - 'status': async () => { - const healthStatus = document.getElementById('healthStatus'); - const configStatus = document.getElementById('configStatus'); - - try { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - - healthStatus.innerHTML = ` -
-
-
- ${data.status} -
-
-

Working Directory: ${data.working_directory}

-

Input Directory: ${data.input_directory}

-

Indexed Files: ${data.indexed_files_count}

-
-
- `; - - configStatus.innerHTML = Object.entries(data.configuration) - .map(([key, value]) => ` -
- ${key}: - ${value} -
- `).join(''); - } catch (error) { - showToast('Error fetching status'); - } - }, - - 'settings': () => { - const saveBtn = document.getElementById('saveSettings'); - const apiKeyInput = document.getElementById('apiKeyInput'); - - saveBtn.addEventListener('click', () => { - state.apiKey = apiKeyInput.value; - localStorage.setItem('apiKey', state.apiKey); - showToast('Settings saved successfully'); - }); - } - }; - - // Navigation handling - document.querySelectorAll('.nav-item').forEach(item => { - item.addEventListener('click', (e) => { - e.preventDefault(); - const page = item.dataset.page; - document.getElementById('content').innerHTML = pages[page](); - if (handlers[page]) handlers[page](); - state.currentPage = page; - }); - }); - - // Initialize with file manager - document.getElementById('content').innerHTML = pages['file-manager'](); - handlers['file-manager'](); - - // Global functions - window.removeFile = (fileName) => { - state.files = state.files.filter(file => file.name !== fileName); - document.querySelector('#fileList div').innerHTML = state.files.map(file => ` -
- ${file.name} - -
- `).join(''); - }; - diff --git a/lightrag/api/webui/static/js/graph.js b/lightrag/api/static/js/graph.js similarity index 100% rename from lightrag/api/webui/static/js/graph.js rename to lightrag/api/static/js/graph.js diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/lightrag_api.js new file mode 100644 index 00000000..67e258b9 --- /dev/null +++ b/lightrag/api/static/js/lightrag_api.js @@ -0,0 +1,353 @@ +// State management +const state = { + apiKey: localStorage.getItem('apiKey') || '', + files: [], + indexedFiles: [], + currentPage: 'file-manager' +}; + +// Utility functions +const showToast = (message, duration = 3000) => { + const toast = document.getElementById('toast'); + toast.querySelector('div').textContent = message; + toast.classList.remove('hidden'); + setTimeout(() => toast.classList.add('hidden'), duration); +}; + +const fetchWithAuth = async (url, options = {}) => { + const headers = { + ...(options.headers || {}), + ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {}) + }; + return fetch(url, { ...options, headers }); +}; + +// Page renderers +const pages = { + 'file-manager': () => ` +
+

File Manager

+ +
+ + +
+ +
+

Selected Files

+
+
+ + + + +
+

Indexed Files

+
+
+ +
+ `, + + 'query': () => ` +
+

Query Database

+ +
+
+ + +
+ +
+ + +
+ + + +
+
+
+ `, + + 'knowledge-graph': () => ` +
+
+ + + +

Under Construction

+

Knowledge graph visualization will be available in a future update.

+
+
+ `, + + 'status': () => ` +
+

System Status

+
+
+

System Health

+
+
+
+

Configuration

+
+
+
+
+ `, + + 'settings': () => ` +
+

Settings

+ +
+
+
+ + +
+ + +
+
+
+ ` +}; + +// Page handlers +const handlers = { + 'file-manager': () => { + const fileInput = document.getElementById('fileInput'); + const dropZone = fileInput.parentElement.parentElement; + const fileList = document.querySelector('#fileList div'); + const indexedFiles = document.querySelector('#indexedFiles div'); + const uploadBtn = document.getElementById('uploadBtn'); + + const updateFileList = () => { + fileList.innerHTML = state.files.map(file => ` +
+ ${file.name} + +
+ `).join(''); + }; + + const updateIndexedFiles = async () => { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + indexedFiles.innerHTML = data.indexed_files.map(file => ` +
+ ${file} +
+ `).join(''); + }; + + dropZone.addEventListener('dragover', (e) => { + e.preventDefault(); + dropZone.classList.add('border-blue-500'); + }); + + dropZone.addEventListener('dragleave', () => { + dropZone.classList.remove('border-blue-500'); + }); + + dropZone.addEventListener('drop', (e) => { + e.preventDefault(); + dropZone.classList.remove('border-blue-500'); + const files = Array.from(e.dataTransfer.files); + state.files.push(...files); + updateFileList(); + }); + + fileInput.addEventListener('change', () => { + state.files.push(...Array.from(fileInput.files)); + updateFileList(); + }); + + uploadBtn.addEventListener('click', async () => { + if (state.files.length === 0) { + showToast('Please select files to upload'); + return; + } + let apiKey = localStorage.getItem('apiKey') || ''; + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + + for (let i = 0; i < state.files.length; i++) { + const formData = new FormData(); + formData.append('file', state.files[i]); + + try { + await fetch('/documents/upload', { + method: 'POST', + headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, + body: formData + }); + + const percentage = ((i + 1) / state.files.length) * 100; + progressBar.style.width = `${percentage}%`; + statusText.textContent = i + 1; + } catch (error) { + console.error('Upload error:', error); + } + } + progress.classList.add('hidden'); + }); + + updateIndexedFiles(); + }, + + 'query': () => { + const queryBtn = document.getElementById('queryBtn'); + const queryInput = document.getElementById('queryInput'); + const queryMode = document.getElementById('queryMode'); + const queryResult = document.getElementById('queryResult'); + + let apiKey = localStorage.getItem('apiKey') || ''; + + queryBtn.addEventListener('click', async () => { + const query = queryInput.value.trim(); + if (!query) { + showToast('Please enter a query'); + return; + } + + queryBtn.disabled = true; + queryBtn.innerHTML = ` + + + + + Processing... + `; + + try { + const response = await fetchWithAuth('/query', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query, + mode: queryMode.value, + stream: false, + only_need_context: false + }) + }); + + const data = await response.json(); + queryResult.innerHTML = marked.parse(data.response); + } catch (error) { + showToast('Error processing query'); + } finally { + queryBtn.disabled = false; + queryBtn.textContent = 'Send Query'; + } + }); + }, + + 'status': async () => { + const healthStatus = document.getElementById('healthStatus'); + const configStatus = document.getElementById('configStatus'); + + try { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + + healthStatus.innerHTML = ` +
+
+
+ ${data.status} +
+
+

Working Directory: ${data.working_directory}

+

Input Directory: ${data.input_directory}

+

Indexed Files: ${data.indexed_files_count}

+
+
+ `; + + configStatus.innerHTML = Object.entries(data.configuration) + .map(([key, value]) => ` +
+ ${key}: + ${value} +
+ `).join(''); + } catch (error) { + showToast('Error fetching status'); + } + }, + + 'settings': () => { + const saveBtn = document.getElementById('saveSettings'); + const apiKeyInput = document.getElementById('apiKeyInput'); + + saveBtn.addEventListener('click', () => { + state.apiKey = apiKeyInput.value; + localStorage.setItem('apiKey', state.apiKey); + showToast('Settings saved successfully'); + }); + } +}; + +// Navigation handling +document.querySelectorAll('.nav-item').forEach(item => { + item.addEventListener('click', (e) => { + e.preventDefault(); + const page = item.dataset.page; + document.getElementById('content').innerHTML = pages[page](); + if (handlers[page]) handlers[page](); + state.currentPage = page; + }); +}); + +// Initialize with file manager +document.getElementById('content').innerHTML = pages['file-manager'](); +handlers['file-manager'](); + +// Global functions +window.removeFile = (fileName) => { + state.files = state.files.filter(file => file.name !== fileName); + document.querySelector('#fileList div').innerHTML = state.files.map(file => ` +
+ ${file.name} + +
+ `).join(''); +}; \ No newline at end of file diff --git a/lightrag/api/webui/static/__init__.py b/lightrag/api/webui_depricated/static/__init__.py similarity index 100% rename from lightrag/api/webui/static/__init__.py rename to lightrag/api/webui_depricated/static/__init__.py diff --git a/lightrag/api/webui/static/css/__init__.py b/lightrag/api/webui_depricated/static/css/__init__.py similarity index 100% rename from lightrag/api/webui/static/css/__init__.py rename to lightrag/api/webui_depricated/static/css/__init__.py diff --git a/lightrag/api/webui/static/css/graph.css b/lightrag/api/webui_depricated/static/css/graph.css similarity index 100% rename from lightrag/api/webui/static/css/graph.css rename to lightrag/api/webui_depricated/static/css/graph.css diff --git a/lightrag/api/webui/static/css/lightrag.css b/lightrag/api/webui_depricated/static/css/lightrag.css similarity index 100% rename from lightrag/api/webui/static/css/lightrag.css rename to lightrag/api/webui_depricated/static/css/lightrag.css diff --git a/lightrag/api/webui/static/index.html b/lightrag/api/webui_depricated/static/index.html similarity index 100% rename from lightrag/api/webui/static/index.html rename to lightrag/api/webui_depricated/static/index.html diff --git a/lightrag/api/webui/static/js/__init__.py b/lightrag/api/webui_depricated/static/js/__init__.py similarity index 100% rename from lightrag/api/webui/static/js/__init__.py rename to lightrag/api/webui_depricated/static/js/__init__.py diff --git a/lightrag/api/webui/static/js/lightrag.js b/lightrag/api/webui_depricated/static/js/lightrag.js similarity index 100% rename from lightrag/api/webui/static/js/lightrag.js rename to lightrag/api/webui_depricated/static/js/lightrag.js From 0721ee303c40548d0d515274fdee4e5c34f8f827 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 12:02:22 +0100 Subject: [PATCH 30/33] Fixed files list --- lightrag/api/lightrag_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1c1fe3a6..b558a228 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1798,12 +1798,13 @@ def create_app(args): @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" + files = doc_manager.scan_directory() return { "status": "healthy", "working_directory": str(args.working_dir), "input_directory": str(args.input_dir), - "indexed_files": doc_manager.indexed_files, - "indexed_files_count": len(doc_manager.indexed_files), + "indexed_files": files, + "indexed_files_count": len(files), "configuration": { # LLM configuration binding/host address (if applicable)/model (if applicable) "llm_binding": args.llm_binding, From 340ba407702508b13e98f30ca4539e1f4d63e281 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 12:25:59 +0100 Subject: [PATCH 31/33] Added rescan button --- lightrag/api/static/js/lightrag_api.js | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/lightrag_api.js index 67e258b9..94e85eb6 100644 --- a/lightrag/api/static/js/lightrag_api.js +++ b/lightrag/api/static/js/lightrag_api.js @@ -58,6 +58,12 @@ const pages = {

Indexed Files

+ `, @@ -225,7 +231,22 @@ const handlers = { } progress.classList.add('hidden'); }); - + rescanBtn.addEventListener('click', async () => { + let apiKey = localStorage.getItem('apiKey') || ''; + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + try { + const scan_output = await fetch('/documents/scan', { + method: 'GET', + }); + statusText.textContent = scan_output.data; + } catch (error) { + console.error('Upload error:', error); + } + progress.classList.add('hidden'); + }); updateIndexedFiles(); }, From a4156fed195232a765c46135ae969dabe01a1755 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 12:44:22 +0100 Subject: [PATCH 32/33] Fixed ui --- lightrag/api/static/js/lightrag_api.js | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/lightrag_api.js index 94e85eb6..d9f75b41 100644 --- a/lightrag/api/static/js/lightrag_api.js +++ b/lightrag/api/static/js/lightrag_api.js @@ -58,13 +58,14 @@ const pages = {

Indexed Files

- + `, From e4b2a5956eb78f6c6562cb176c654bc7372cdc2e Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 27 Jan 2025 12:49:12 +0100 Subject: [PATCH 33/33] Upgraded ui --- lightrag/api/static/js/lightrag_api.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/lightrag_api.js index d9f75b41..2b13a726 100644 --- a/lightrag/api/static/js/lightrag_api.js +++ b/lightrag/api/static/js/lightrag_api.js @@ -225,7 +225,7 @@ const handlers = { const percentage = ((i + 1) / state.files.length) * 100; progressBar.style.width = `${percentage}%`; - statusText.textContent = i + 1; + statusText.textContent = `${i + 1}/${state.files.length}`; } catch (error) { console.error('Upload error:', error); }