diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index 56642185..d082a170 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -1,4 +1,8 @@
 import networkx as nx
+import pipmaster as pm
+if not pm.is_installed("pyvis"):
+    pm.install("pyvis")
+
 from pyvis.network import Network
 import random
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 792b9435..35e4acf7 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -1798,12 +1798,13 @@ def create_app(args):
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
+        files = doc_manager.scan_directory()
         return {
             "status": "healthy",
             "working_directory": str(args.working_dir),
             "input_directory": str(args.input_dir),
-            "indexed_files": doc_manager.indexed_files,
-            "indexed_files_count": len(doc_manager.indexed_files),
+            "indexed_files": files,
+            "indexed_files_count": len(files),
             "configuration": {
                 # LLM configuration binding/host address (if applicable)/model (if applicable)
                 "llm_binding": args.llm_binding,
diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt
index fc5afd58..7b2593c0 100644
--- a/lightrag/api/requirements.txt
+++ b/lightrag/api/requirements.txt
@@ -1,10 +1,7 @@
 ascii_colors
 fastapi
-nano_vectordb
 nest_asyncio
 numpy
-ollama
-openai
 pipmaster
 python-dotenv
 python-multipart
@@ -12,5 +9,4 @@ tenacity
 tiktoken
 torch
 tqdm
-transformers
 uvicorn
diff --git a/lightrag/api/static/index.html b/lightrag/api/static/index.html
index 56a70ad7..60900c03 100644
--- a/lightrag/api/static/index.html
+++ b/lightrag/api/static/index.html
@@ -98,358 +98,7 @@
-
-        // Utility functions
-        const showToast = (message, duration = 3000) => {
-            const toast = document.getElementById('toast');
-            toast.querySelector('div').textContent = message;
-            toast.classList.remove('hidden');
-            setTimeout(() => toast.classList.add('hidden'), duration);
-        };
-
-        const fetchWithAuth = async (url, options = {}) => {
-            const headers = {
-                ...(options.headers || {}),
-                ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
-            };
-            return fetch(url, { ...options, headers });
-        };
-
-        // Page renderers
-        const pages = {
-            'file-manager': () => `
-
-

File Manager

- -
- - -
- -
-

Selected Files

-
-
- - - - -
-

Indexed Files

-
-
- -
- `, - - 'query': () => ` -
-

Query Database

- -
-
- - -
- -
- - -
- - - -
-
-
- `, - - 'knowledge-graph': () => ` -
-
- - - -

Under Construction

-

Knowledge graph visualization will be available in a future update.

-
-
- `, - - 'status': () => ` -
-

System Status

-
-
-

System Health

-
-
-
-

Configuration

-
-
-
-
- `, - - 'settings': () => ` -
-

Settings

- -
-
-
- - -
- - -
-
-
- ` - }; - - // Page handlers - const handlers = { - 'file-manager': () => { - const fileInput = document.getElementById('fileInput'); - const dropZone = fileInput.parentElement.parentElement; - const fileList = document.querySelector('#fileList div'); - const indexedFiles = document.querySelector('#indexedFiles div'); - const uploadBtn = document.getElementById('uploadBtn'); - - const updateFileList = () => { - fileList.innerHTML = state.files.map(file => ` -
- ${file.name} - -
- `).join(''); - }; - - const updateIndexedFiles = async () => { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - indexedFiles.innerHTML = data.indexed_files.map(file => ` -
- ${file} -
- `).join(''); - }; - - dropZone.addEventListener('dragover', (e) => { - e.preventDefault(); - dropZone.classList.add('border-blue-500'); - }); - - dropZone.addEventListener('dragleave', () => { - dropZone.classList.remove('border-blue-500'); - }); - - dropZone.addEventListener('drop', (e) => { - e.preventDefault(); - dropZone.classList.remove('border-blue-500'); - const files = Array.from(e.dataTransfer.files); - state.files.push(...files); - updateFileList(); - }); - - fileInput.addEventListener('change', () => { - state.files.push(...Array.from(fileInput.files)); - updateFileList(); - }); - - uploadBtn.addEventListener('click', async () => { - if (state.files.length === 0) { - showToast('Please select files to upload'); - return; - } - let apiKey = localStorage.getItem('apiKey') || ''; - const progress = document.getElementById('uploadProgress'); - const progressBar = progress.querySelector('div'); - const statusText = document.getElementById('uploadStatus'); - progress.classList.remove('hidden'); - - for (let i = 0; i < state.files.length; i++) { - const formData = new FormData(); - formData.append('file', state.files[i]); - - try { - await fetch('/documents/upload', { - method: 'POST', - headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, - body: formData - }); - - const percentage = ((i + 1) / state.files.length) * 100; - progressBar.style.width = `${percentage}%`; - statusText.textContent = i + 1; - } catch (error) { - console.error('Upload error:', error); - } - } - progress.classList.add('hidden'); - }); - - updateIndexedFiles(); - }, - - 'query': () => { - const queryBtn = document.getElementById('queryBtn'); - const queryInput = document.getElementById('queryInput'); - const queryMode = document.getElementById('queryMode'); - const queryResult = document.getElementById('queryResult'); - - queryBtn.addEventListener('click', async () => { - const query = queryInput.value.trim(); - if (!query) { - showToast('Please enter a query'); - return; - } - - queryBtn.disabled = true; - queryBtn.innerHTML = ` - - - - - Processing... - `; - - try { - const response = await fetchWithAuth('/query', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - query, - mode: queryMode.value, - stream: false, - only_need_context: false - }) - }); - - const data = await response.json(); - queryResult.innerHTML = marked.parse(data.response); - } catch (error) { - showToast('Error processing query'); - } finally { - queryBtn.disabled = false; - queryBtn.textContent = 'Send Query'; - } - }); - }, - - 'status': async () => { - const healthStatus = document.getElementById('healthStatus'); - const configStatus = document.getElementById('configStatus'); - - try { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - - healthStatus.innerHTML = ` -
-
-
- ${data.status} -
-
-

Working Directory: ${data.working_directory}

-

Input Directory: ${data.input_directory}

-

Indexed Files: ${data.indexed_files_count}

-
-
- `; - - configStatus.innerHTML = Object.entries(data.configuration) - .map(([key, value]) => ` -
- ${key}: - ${value} -
- `).join(''); - } catch (error) { - showToast('Error fetching status'); - } - }, - - 'settings': () => { - const saveBtn = document.getElementById('saveSettings'); - const apiKeyInput = document.getElementById('apiKeyInput'); - - saveBtn.addEventListener('click', () => { - state.apiKey = apiKeyInput.value; - localStorage.setItem('apiKey', state.apiKey); - showToast('Settings saved successfully'); - }); - } - }; - - // Navigation handling - document.querySelectorAll('.nav-item').forEach(item => { - item.addEventListener('click', (e) => { - e.preventDefault(); - const page = item.dataset.page; - document.getElementById('content').innerHTML = pages[page](); - if (handlers[page]) handlers[page](); - state.currentPage = page; - }); - }); - - // Initialize with file manager - document.getElementById('content').innerHTML = pages['file-manager'](); - handlers['file-manager'](); - - // Global functions - window.removeFile = (fileName) => { - state.files = state.files.filter(file => file.name !== fileName); - document.querySelector('#fileList div').innerHTML = state.files.map(file => ` -
- ${file.name} - -
- `).join(''); - }; - diff --git a/lightrag/api/webui/static/js/graph.js b/lightrag/api/static/js/graph.js similarity index 100% rename from lightrag/api/webui/static/js/graph.js rename to lightrag/api/static/js/graph.js diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/lightrag_api.js new file mode 100644 index 00000000..2b13a726 --- /dev/null +++ b/lightrag/api/static/js/lightrag_api.js @@ -0,0 +1,375 @@ +// State management +const state = { + apiKey: localStorage.getItem('apiKey') || '', + files: [], + indexedFiles: [], + currentPage: 'file-manager' +}; + +// Utility functions +const showToast = (message, duration = 3000) => { + const toast = document.getElementById('toast'); + toast.querySelector('div').textContent = message; + toast.classList.remove('hidden'); + setTimeout(() => toast.classList.add('hidden'), duration); +}; + +const fetchWithAuth = async (url, options = {}) => { + const headers = { + ...(options.headers || {}), + ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {}) + }; + return fetch(url, { ...options, headers }); +}; + +// Page renderers +const pages = { + 'file-manager': () => ` +
+

File Manager

+ +
+ + +
+ +
+

Selected Files

+
+
+ + + + +
+

Indexed Files

+
+
+ + + +
+ `, + + 'query': () => ` +
+

Query Database

+ +
+
+ + +
+ +
+ + +
+ + + +
+
+
+ `, + + 'knowledge-graph': () => ` +
+
+ + + +

Under Construction

+

Knowledge graph visualization will be available in a future update.

+
+
+ `, + + 'status': () => ` +
+

System Status

+
+
+

System Health

+
+
+
+

Configuration

+
+
+
+
+ `, + + 'settings': () => ` +
+

Settings

+ +
+
+
+ + +
+ + +
+
+
+ ` +}; + +// Page handlers +const handlers = { + 'file-manager': () => { + const fileInput = document.getElementById('fileInput'); + const dropZone = fileInput.parentElement.parentElement; + const fileList = document.querySelector('#fileList div'); + const indexedFiles = document.querySelector('#indexedFiles div'); + const uploadBtn = document.getElementById('uploadBtn'); + + const updateFileList = () => { + fileList.innerHTML = state.files.map(file => ` +
+ ${file.name} + +
+ `).join(''); + }; + + const updateIndexedFiles = async () => { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + indexedFiles.innerHTML = data.indexed_files.map(file => ` +
+ ${file} +
+ `).join(''); + }; + + dropZone.addEventListener('dragover', (e) => { + e.preventDefault(); + dropZone.classList.add('border-blue-500'); + }); + + dropZone.addEventListener('dragleave', () => { + dropZone.classList.remove('border-blue-500'); + }); + + dropZone.addEventListener('drop', (e) => { + e.preventDefault(); + dropZone.classList.remove('border-blue-500'); + const files = Array.from(e.dataTransfer.files); + state.files.push(...files); + updateFileList(); + }); + + fileInput.addEventListener('change', () => { + state.files.push(...Array.from(fileInput.files)); + updateFileList(); + }); + + uploadBtn.addEventListener('click', async () => { + if (state.files.length === 0) { + showToast('Please select files to upload'); + return; + } + let apiKey = localStorage.getItem('apiKey') || ''; + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + + for (let i = 0; i < state.files.length; i++) { + const formData = new FormData(); + formData.append('file', state.files[i]); + + try { + await fetch('/documents/upload', { + method: 'POST', + headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, + body: formData + }); + + const percentage = ((i + 1) / state.files.length) * 100; + progressBar.style.width = `${percentage}%`; + statusText.textContent = `${i + 1}/${state.files.length}`; + } catch (error) { + console.error('Upload error:', error); + } + } + progress.classList.add('hidden'); + }); + rescanBtn.addEventListener('click', async () => { + let apiKey = localStorage.getItem('apiKey') || ''; + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + try { + const scan_output = await fetch('/documents/scan', { + method: 'GET', + }); + statusText.textContent = scan_output.data; + } catch (error) { + console.error('Upload error:', error); + } + progress.classList.add('hidden'); + }); + updateIndexedFiles(); + }, + + 'query': () => { + const queryBtn = document.getElementById('queryBtn'); + const queryInput = document.getElementById('queryInput'); + const queryMode = document.getElementById('queryMode'); + const queryResult = document.getElementById('queryResult'); + + let apiKey = localStorage.getItem('apiKey') || ''; + + queryBtn.addEventListener('click', async () => { + const query = queryInput.value.trim(); + if (!query) { + showToast('Please enter a query'); + return; + } + + queryBtn.disabled = true; + queryBtn.innerHTML = ` + + + + + Processing... + `; + + try { + const response = await fetchWithAuth('/query', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query, + mode: queryMode.value, + stream: false, + only_need_context: false + }) + }); + + const data = await response.json(); + queryResult.innerHTML = marked.parse(data.response); + } catch (error) { + showToast('Error processing query'); + } finally { + queryBtn.disabled = false; + queryBtn.textContent = 'Send Query'; + } + }); + }, + + 'status': async () => { + const healthStatus = document.getElementById('healthStatus'); + const configStatus = document.getElementById('configStatus'); + + try { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + + healthStatus.innerHTML = ` +
+
+
+ ${data.status} +
+
+

Working Directory: ${data.working_directory}

+

Input Directory: ${data.input_directory}

+

Indexed Files: ${data.indexed_files_count}

+
+
+ `; + + configStatus.innerHTML = Object.entries(data.configuration) + .map(([key, value]) => ` +
+ ${key}: + ${value} +
+ `).join(''); + } catch (error) { + showToast('Error fetching status'); + } + }, + + 'settings': () => { + const saveBtn = document.getElementById('saveSettings'); + const apiKeyInput = document.getElementById('apiKeyInput'); + + saveBtn.addEventListener('click', () => { + state.apiKey = apiKeyInput.value; + localStorage.setItem('apiKey', state.apiKey); + showToast('Settings saved successfully'); + }); + } +}; + +// Navigation handling +document.querySelectorAll('.nav-item').forEach(item => { + item.addEventListener('click', (e) => { + e.preventDefault(); + const page = item.dataset.page; + document.getElementById('content').innerHTML = pages[page](); + if (handlers[page]) handlers[page](); + state.currentPage = page; + }); +}); + +// Initialize with file manager +document.getElementById('content').innerHTML = pages['file-manager'](); +handlers['file-manager'](); + +// Global functions +window.removeFile = (fileName) => { + state.files = state.files.filter(file => file.name !== fileName); + document.querySelector('#fileList div').innerHTML = state.files.map(file => ` +
+ ${file.name} + +
+ `).join(''); +}; \ No newline at end of file diff --git a/lightrag/api/webui/static/__init__.py b/lightrag/api/webui_depricated/static/__init__.py similarity index 100% rename from lightrag/api/webui/static/__init__.py rename to lightrag/api/webui_depricated/static/__init__.py diff --git a/lightrag/api/webui/static/css/__init__.py b/lightrag/api/webui_depricated/static/css/__init__.py similarity index 100% rename from lightrag/api/webui/static/css/__init__.py rename to lightrag/api/webui_depricated/static/css/__init__.py diff --git a/lightrag/api/webui/static/css/graph.css b/lightrag/api/webui_depricated/static/css/graph.css similarity index 100% rename from lightrag/api/webui/static/css/graph.css rename to lightrag/api/webui_depricated/static/css/graph.css diff --git a/lightrag/api/webui/static/css/lightrag.css b/lightrag/api/webui_depricated/static/css/lightrag.css similarity index 100% rename from lightrag/api/webui/static/css/lightrag.css rename to lightrag/api/webui_depricated/static/css/lightrag.css diff --git a/lightrag/api/webui/static/index.html b/lightrag/api/webui_depricated/static/index.html similarity index 100% rename from lightrag/api/webui/static/index.html rename to lightrag/api/webui_depricated/static/index.html diff --git a/lightrag/api/webui/static/js/__init__.py b/lightrag/api/webui_depricated/static/js/__init__.py similarity index 100% rename from lightrag/api/webui/static/js/__init__.py rename to lightrag/api/webui_depricated/static/js/__init__.py diff --git a/lightrag/api/webui/static/js/lightrag.js b/lightrag/api/webui_depricated/static/js/lightrag.js similarity index 100% rename from lightrag/api/webui/static/js/lightrag.js rename to lightrag/api/webui_depricated/static/js/lightrag.js diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 275f5775..df32b7cb 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -6,6 +6,14 @@ import sys from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +import pipmaster as pm + +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + import psycopg from psycopg.rows import namedtuple_row diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py new file mode 100644 index 00000000..57fe765d --- /dev/null +++ b/lightrag/kg/json_kv_impl.py @@ -0,0 +1,137 @@ +""" +JsonDocStatus Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. 
+ +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import asyncio +import os +from dataclasses import dataclass + +from lightrag.utils import ( + logger, + load_json, + write_json, +) + +from lightrag.base import ( + BaseKVStorage, +) + + +@dataclass +class JsonKVStorage(BaseKVStorage): + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + self._lock = asyncio.Lock() + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + + async def all_keys(self) -> list[str]: + return list(self._data.keys()) + + async def index_done_callback(self): + write_json(self._data, self._file_name) + + async def get_by_id(self, id): + return self._data.get(id, None) + + async def get_by_ids(self, ids, fields=None): + if fields is None: + return [self._data.get(id, None) for id in ids] + return [ + ( + {k: v for k, v in self._data[id].items() if k in fields} + if self._data.get(id, None) + else None + ) + for id in ids + ] + + async def filter_keys(self, data: list[str]) -> set[str]: + return set([s for s in data if s not in self._data]) + + async def upsert(self, data: dict[str, dict]): + left_data = {k: v for k, v in data.items() if k not in self._data} + self._data.update(left_data) + return left_data + + async def drop(self): + self._data = {} + + async def filter(self, filter_func): + """Filter key-value pairs based on a filter function + + Args: + filter_func: The filter function, which takes a value as an argument and returns a boolean value + + Returns: + Dict: Key-value pairs that meet the condition + """ + result = {} + async with self._lock: + for key, value in self._data.items(): + if filter_func(value): + result[key] = value + return result + + async def delete(self, ids: list[str]): + """Delete data with specified IDs + + Args: + ids: List of IDs to delete + """ + async with self._lock: + for id in ids: + if id in self._data: + del self._data[id] + await self.index_done_callback() + 
logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") + + diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py new file mode 100644 index 00000000..8f326170 --- /dev/null +++ b/lightrag/kg/jsondocstatus_impl.py @@ -0,0 +1,128 @@ +""" +JsonDocStatus Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + +import os +from dataclasses import dataclass +from typing import Union, Dict + +from lightrag.utils import ( + logger, + load_json, + write_json, +) + +from lightrag.base import ( + DocStatus, + DocProcessingStatus, + DocStatusStorage, +) + + +@dataclass +class JsonDocStatusStorage(DocStatusStorage): + """JSON implementation of document status storage""" + + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + logger.info(f"Loaded document status storage with {len(self._data)} records") + + async def filter_keys(self, data: list[str]) -> set[str]: + """Return keys that should be processed (not in storage or not successfully processed)""" + return set( + [ + k + for k in data + if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED + ] + ) + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + counts = {status: 0 for status in DocStatus} + for doc in self._data.values(): + counts[doc["status"]] += 1 + return counts + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} + + async def index_done_callback(self): + """Save data to file after indexing""" + write_json(self._data, self._file_name) + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: Dictionary of document IDs and their status data + """ + self._data.update(data) + await self.index_done_callback() + return data + + async def get_by_id(self, id: str): + return self._data.get(id) + + async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: + """Get document status by ID""" + return self._data.get(doc_id) + + async def delete(self, doc_ids: list[str]): + """Delete document status by IDs""" + for doc_id in doc_ids: + self._data.pop(doc_id, None) + await self.index_done_callback() diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index bf20ffd7..905a08b5 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,6 +6,9 @@ import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage +import pipmaster as pm +if not pm.is_installed("pymilvus"): + pm.install("pymilvus") from pymilvus import MilvusClient diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index fbbae8c2..9515514a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,6 +1,10 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass +import pipmaster as pm +if not pm.is_installed("pymongo"): + pm.install("pymongo") + from pymongo import MongoClient from typing import Union from lightrag.utils import logger diff --git 
a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py new file mode 100644 index 00000000..f2372799 --- /dev/null +++ b/lightrag/kg/nano_vector_db_impl.py @@ -0,0 +1,206 @@ +""" +NanoVectorDB Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" +import asyncio +import os +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass +import numpy as np +import pipmaster as pm + +if not pm.is_installed("nano-vectordb"): + pm.install("nano-vectordb") + +from nano_vectordb import NanoVectorDB +import time + +from lightrag.utils import ( + logger, + compute_mdhash_id, +) + +from lightrag.base import ( + BaseVectorStorage, +) + + +@dataclass +class NanoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + self._client_file_name = os.path.join( + self.global_config["working_dir"], f"vdb_{self.namespace}.json" + ) + self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + + current_time = time.time() + list_data = [ + { + "__id__": k, + "__created_at__": current_time, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in 
data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async( + total=len(embedding_tasks), desc="Generating embeddings", unit="batch" + ) + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + if len(embeddings) == len(list_data): + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + results = self._client.upsert(datas=list_data) + return results + else: + # sometimes the embedding is not returned correctly. just log it. + logger.error( + f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" + ) + + async def query(self, query: str, top_k=5): + embedding = await self.embedding_func([query]) + embedding = embedding[0] + results = self._client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] + return results + + @property + def client_storage(self): + return getattr(self._client, "_NanoVectorDB__storage") + + async def delete(self, ids: list[str]): + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + self._client.delete(ids) + logger.info( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + + async def delete_entity(self, entity_name: str): + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Check if the entity exists + if self._client.get([entity_id]): + await self.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str): + try: + relations = [ + dp + for dp in self.client_storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + ids_to_delete = [relation["__id__"] for relation in relations] + + if ids_to_delete: + await self.delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") + + async def index_done_callback(self): + self._client.save() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 4392a834..cd552122 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,6 +3,9 @@ import inspect import os from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict +import pipmaster as pm +if not pm.is_installed("neo4j"): + pm.install("neo4j") from neo4j import ( AsyncGraphDatabase, diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py new file mode 100644 index 00000000..493c551e --- /dev/null +++ b/lightrag/kg/networkx_impl.py @@ -0,0 
+1,227 @@ +""" +NetworkX Storage Module +======================= + +This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. + +The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. + +Author: lightrag team +Created: 2024-01-25 +License: MIT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" +import html +import os +from dataclasses import dataclass +from typing import Any, Union, cast +import networkx as nx +import numpy as np + + +from lightrag.utils import ( + logger, +) + +from lightrag.base import ( + BaseGraphStorage, +) + + +@dataclass +class NetworkXStorage(BaseGraphStorage): + @staticmethod + def load_nx_graph(file_name) -> nx.Graph: + if os.path.exists(file_name): + return nx.read_graphml(file_name) + return None + + @staticmethod + def write_nx_graph(graph: nx.Graph, file_name): + logger.info( + f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" + ) + nx.write_graphml(graph, file_name) + + @staticmethod + def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + Return the largest connected component of the graph, with nodes and edges sorted in a stable way. + """ + from graspologic.utils import largest_connected_component + + graph = graph.copy() + graph = cast(nx.Graph, largest_connected_component(graph)) + node_mapping = { + node: html.unescape(node.upper().strip()) for node in graph.nodes() + } # type: ignore + graph = nx.relabel_nodes(graph, node_mapping) + return NetworkXStorage._stabilize_graph(graph) + + @staticmethod + def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + Ensure an undirected graph with the same relationships will always be read the same way. 
+ """ + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + + sorted_nodes = graph.nodes(data=True) + sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + + fixed_graph.add_nodes_from(sorted_nodes) + edges = list(graph.edges(data=True)) + + if not graph.is_directed(): + + def _sort_source_target(edge): + source, target, edge_data = edge + if source > target: + temp = source + source = target + target = temp + return source, target, edge_data + + edges = [_sort_source_target(edge) for edge in edges] + + def _get_edge_key(source: Any, target: Any) -> str: + return f"{source} -> {target}" + + edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) + + fixed_graph.add_edges_from(edges) + return fixed_graph + + def __post_init__(self): + self._graphml_xml_file = os.path.join( + self.global_config["working_dir"], f"graph_{self.namespace}.graphml" + ) + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + self._graph = preloaded_graph or nx.Graph() + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def index_done_callback(self): + NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + + async def has_node(self, node_id: str) -> bool: + return self._graph.has_node(node_id) + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + return self._graph.has_edge(source_node_id, target_node_id) + + async def get_node(self, node_id: str) -> Union[dict, None]: + return self._graph.nodes.get(node_id) + + async def node_degree(self, node_id: str) -> int: + return self._graph.degree(node_id) + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + return self._graph.degree(src_id) + self._graph.degree(tgt_id) + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + return self._graph.edges.get((source_node_id, target_node_id)) + + async def get_node_edges(self, source_node_id: str): + if self._graph.has_node(source_node_id): + return list(self._graph.edges(source_node_id)) + return None + + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + self._graph.add_node(node_id, **node_data) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + self._graph.add_edge(source_node_id, target_node_id, **edge_data) + + async def delete_node(self, node_id: str): + """ + Delete a node from the graph based on the specified node_id. 
+ + :param node_id: The node_id to delete + """ + if self._graph.has_node(node_id): + self._graph.remove_node(node_id) + logger.info(f"Node {node_id} deleted from the graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() + + # @TODO: NOT USED + async def _node2vec_embed(self): + from graspologic import embed + + embeddings, nodes = embed.node2vec_embed( + self._graph, + **self.global_config["node2vec_params"], + ) + + nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + return embeddings, nodes_ids + + def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node IDs to be deleted + """ + for node in nodes: + if self._graph.has_node(node): + self._graph.remove_node(node) + + def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + if self._graph.has_edge(source, target): + self._graph.remove_edge(source, target) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index f93d2816..2d1f631c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -6,6 +6,11 @@ from dataclasses import dataclass from typing import Union import numpy as np import array +import pipmaster as pm + +if not pm.is_installed("oracledb"): + pm.install("oracledb") + from ..utils import logger from ..base import ( diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 86072c9f..efeb7cf5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -6,6 +6,11 @@ import time from dataclasses import dataclass from typing import Union, List, Dict, Set, Any, Tuple import numpy as np + +import pipmaster as pm +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + import asyncpg import sys from tqdm.asyncio import tqdm as tqdm_async diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py index 274f03de..eb6e6e73 100644 --- a/lightrag/kg/postgres_impl_test.py +++ b/lightrag/kg/postgres_impl_test.py @@ -1,8 +1,15 @@ import asyncio -import asyncpg import sys import os +import pipmaster as pm +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + +import asyncpg import psycopg from psycopg_pool import AsyncConnectionPool from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index a126074d..013196e3 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,6 +1,9 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass +import pipmaster as pm +if not pm.is_installed("redis"): + pm.install("redis") # aioredis is a depricated library, replaced with redis from redis.asyncio import Redis diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 2cf698e1..8ba1de65 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -4,13 +4,18 @@ from dataclasses import dataclass from typing import Union import numpy as np +import pipmaster as pm +if not 
pm.is_installed("pymysql"): + pm.install("pymysql") +if not pm.is_installed("sqlalchemy"): + pm.install("sqlalchemy") + from sqlalchemy import create_engine, text from tqdm import tqdm from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage from lightrag.utils import logger - class TiDB(object): def __init__(self, config, **kwargs): self.host = config.get("host", None) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7e8a3bb7..b40eecaa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -38,10 +38,10 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP STORAGES = { - "JsonKVStorage": ".storage", - "NanoVectorDBStorage": ".storage", - "NetworkXStorage": ".storage", - "JsonDocStatusStorage": ".storage", + "NetworkXStorage": ".kg.networkx_impl", + "JsonKVStorage": ".kg.json_kv_impl", + "NanoVectorDBStorage": ".kg.nano_vector_db_impl", + "JsonDocStatusStorage": ".kg.jsondocstatus_impl", "Neo4JStorage": ".kg.neo4j_impl", "OracleKVStorage": ".kg.oracle_impl", "OracleGraphStorage": ".kg.oracle_impl", diff --git a/lightrag/storage.py b/lightrag/storage.py deleted file mode 100644 index 3bee911b..00000000 --- a/lightrag/storage.py +++ /dev/null @@ -1,460 +0,0 @@ -import asyncio -import html -import os -from tqdm.asyncio import tqdm as tqdm_async -from dataclasses import dataclass -from typing import Any, Union, cast, Dict -import networkx as nx -import numpy as np - -from nano_vectordb import NanoVectorDB -import time - -from .utils import ( - logger, - load_json, - write_json, - compute_mdhash_id, -) - -from .base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, - DocStatus, - DocProcessingStatus, - DocStatusStorage, -) - - -@dataclass -class JsonKVStorage(BaseKVStorage): - def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - self._lock = asyncio.Lock() - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - - async def all_keys(self) -> list[str]: - return list(self._data.keys()) - - async def index_done_callback(self): - write_json(self._data, self._file_name) - - async def get_by_id(self, id): - return self._data.get(id, None) - - async def get_by_ids(self, ids, fields=None): - if fields is None: - return [self._data.get(id, None) for id in ids] - return [ - ( - {k: v for k, v in self._data[id].items() if k in fields} - if self._data.get(id, None) - else None - ) - for id in ids - ] - - async def filter_keys(self, data: list[str]) -> set[str]: - return set([s for s in data if s not in self._data]) - - async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - return left_data - - async def drop(self): - self._data = {} - - async def filter(self, filter_func): - """Filter key-value pairs based on a filter function - - Args: - filter_func: The filter function, which takes a value as an argument and returns a boolean value - - Returns: - Dict: Key-value pairs that meet the condition - """ - result = {} - async with self._lock: - for key, value in self._data.items(): - if filter_func(value): - result[key] = value - return result - - async def delete(self, ids: list[str]): - """Delete data with specified IDs - - Args: - ids: List of IDs to delete - """ - async with self._lock: - for id in ids: - if id in self._data: - del self._data[id] - await self.index_done_callback() - 
logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") - - -@dataclass -class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = 0.2 - - def __post_init__(self): - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) - self._max_batch_size = self.global_config["embedding_batch_num"] - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) - self.cosine_better_than_threshold = self.global_config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) - - async def upsert(self, data: dict[str, dict]): - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") - if not len(data): - logger.warning("You insert an empty data to vector DB") - return [] - - current_time = time.time() - list_data = [ - { - "__id__": k, - "__created_at__": current_time, - **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - - async def wrapped_task(batch): - result = await self.embedding_func(batch) - pbar.update(1) - return result - - embedding_tasks = [wrapped_task(batch) for batch in batches] - pbar = tqdm_async( - total=len(embedding_tasks), desc="Generating embeddings", unit="batch" - ) - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - if len(embeddings) == len(list_data): - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) - return results - else: - # sometimes the embedding is not returned correctly. just log it. 
- logger.error( - f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" - ) - - async def query(self, query: str, top_k=5): - embedding = await self.embedding_func([query]) - embedding = embedding[0] - results = self._client.query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] - return results - - @property - def client_storage(self): - return getattr(self._client, "_NanoVectorDB__storage") - - async def delete(self, ids: list[str]): - """Delete vectors with specified IDs - - Args: - ids: List of vector IDs to be deleted - """ - try: - self._client.delete(ids) - logger.info( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - - async def delete_entity(self, entity_name: str): - try: - entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" - ) - # Check if the entity exists - if self._client.get([entity_id]): - await self.delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - - async def delete_entity_relation(self, entity_name: str): - try: - relations = [ - dp - for dp in self.client_storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") - ids_to_delete = [relation["__id__"] for relation in relations] - - if ids_to_delete: - await self.delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for {entity_name}: {e}") - - async def index_done_callback(self): - self._client.save() - - -@dataclass -class NetworkXStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name) -> nx.Graph: - if os.path.exists(file_name): - return nx.read_graphml(file_name) - return None - - @staticmethod - def write_nx_graph(graph: nx.Graph, file_name): - logger.info( - f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" - ) - nx.write_graphml(graph, file_name) - - @staticmethod - def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Return the largest connected component of the graph, with nodes and edges sorted in a stable way. - """ - from graspologic.utils import largest_connected_component - - graph = graph.copy() - graph = cast(nx.Graph, largest_connected_component(graph)) - node_mapping = { - node: html.unescape(node.upper().strip()) for node in graph.nodes() - } # type: ignore - graph = nx.relabel_nodes(graph, node_mapping) - return NetworkXStorage._stabilize_graph(graph) - - @staticmethod - def _stabilize_graph(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Ensure an undirected graph with the same relationships will always be read the same way. 
- """ - fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() - - sorted_nodes = graph.nodes(data=True) - sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) - - fixed_graph.add_nodes_from(sorted_nodes) - edges = list(graph.edges(data=True)) - - if not graph.is_directed(): - - def _sort_source_target(edge): - source, target, edge_data = edge - if source > target: - temp = source - source = target - target = temp - return source, target, edge_data - - edges = [_sort_source_target(edge) for edge in edges] - - def _get_edge_key(source: Any, target: Any) -> str: - return f"{source} -> {target}" - - edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) - - fixed_graph.add_edges_from(edges) - return fixed_graph - - def __post_init__(self): - self._graphml_xml_file = os.path.join( - self.global_config["working_dir"], f"graph_{self.namespace}.graphml" - ) - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - self._graph = preloaded_graph or nx.Graph() - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } - - async def index_done_callback(self): - NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) - - async def has_node(self, node_id: str) -> bool: - return self._graph.has_node(node_id) - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - return self._graph.has_edge(source_node_id, target_node_id) - - async def get_node(self, node_id: str) -> Union[dict, None]: - return self._graph.nodes.get(node_id) - - async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - return self._graph.edges.get((source_node_id, target_node_id)) - - async def get_node_edges(self, source_node_id: str): - if self._graph.has_node(source_node_id): - return list(self._graph.edges(source_node_id)) - return None - - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - self._graph.add_node(node_id, **node_data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): - self._graph.add_edge(source_node_id, target_node_id, **edge_data) - - async def delete_node(self, node_id: str): - """ - Delete a node from the graph based on the specified node_id. 
- - :param node_id: The node_id to delete - """ - if self._graph.has_node(node_id): - self._graph.remove_node(node_id) - logger.info(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") - - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - if algorithm not in self._node_embed_algorithms: - raise ValueError(f"Node embedding algorithm {algorithm} not supported") - return await self._node_embed_algorithms[algorithm]() - - # @TODO: NOT USED - async def _node2vec_embed(self): - from graspologic import embed - - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.global_config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] - return embeddings, nodes_ids - - def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node IDs to be deleted - """ - for node in nodes: - if self._graph.has_node(node): - self._graph.remove_node(node) - - def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - if self._graph.has_edge(source, target): - self._graph.remove_edge(source, target) - - -@dataclass -class JsonDocStatusStorage(DocStatusStorage): - """JSON implementation of document status storage""" - - def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - logger.info(f"Loaded document status storage with {len(self._data)} records") - - async def filter_keys(self, data: list[str]) -> set[str]: - """Return keys that should be processed (not in storage or not successfully processed)""" - return set( - [ - k - for k in data - if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED - ] - ) - - async def get_status_counts(self) -> Dict[str, int]: - """Get counts of documents in each status""" - counts = {status: 0 for status in DocStatus} - for doc in self._data.values(): - counts[doc["status"]] += 1 - return counts - - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} - - async def index_done_callback(self): - """Save data to file after indexing""" - write_json(self._data, self._file_name) - - async def upsert(self, data: dict[str, dict]): - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ - self._data.update(data) - await self.index_done_callback() - return data - - async def get_by_id(self, id: str): - return self._data.get(id) - - async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: - """Get document status by ID""" - return self._data.get(doc_id) - - async def delete(self, doc_ids: list[str]): - """Delete document status by IDs""" - for doc_id in doc_ids: - self._data.pop(doc_id, None) - await self.index_done_callback() diff --git a/lightrag/utils.py b/lightrag/utils.py index ba88b7e4..9792e251 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -16,7 +16,9 @@ import numpy as np import tiktoken from 
lightrag.prompt import PROMPTS - +from typing import List +import csv +import io class UnlimitedSemaphore: """A context manager that allows unlimited access.""" @@ -235,17 +237,39 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: return list_data + + def list_of_list_to_csv(data: List[List[str]]) -> str: output = io.StringIO() - writer = csv.writer(output, quoting=csv.QUOTE_ALL) + writer = csv.writer( + output, + quoting=csv.QUOTE_ALL, # Quote all fields + escapechar='\\', # Use backslash as escape character + quotechar='"', # Use double quotes + lineterminator='\n' # Explicit line terminator + ) writer.writerows(data) return output.getvalue() def csv_string_to_list(csv_string: str) -> List[List[str]]: - output = io.StringIO(csv_string.replace("\x00", "")) - reader = csv.reader(output) - return [row for row in reader] + # Clean the string by removing NUL characters + cleaned_string = csv_string.replace('\0', '') + + output = io.StringIO(cleaned_string) + reader = csv.reader( + output, + quoting=csv.QUOTE_ALL, # Match the writer configuration + escapechar='\\', # Use backslash as escape character + quotechar='"', # Use double quotes + ) + + try: + return [row for row in reader] + except csv.Error as e: + raise ValueError(f"Failed to parse CSV string: {str(e)}") + finally: + output.close() def save_data_to_file(data, file_name): diff --git a/requirements.txt b/requirements.txt index c372cf9b..0f4c18ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,37 +1,24 @@ accelerate aiofiles aiohttp -asyncpg configparser # database packages -graspologic -gremlinpython -nano-vectordb -neo4j networkx +graspologic -# TODO : Remove specific databases and move the installation to their corresponding files -# Use pipmaster for install if needed +# Basic modules numpy -oracledb pipmaster -psycopg-pool -psycopg[binary,pool] pydantic -pymilvus -pymongo -pymysql - +# File manipulation libraries PyPDF2 python-docx python-dotenv python-pptx -pyvis -redis + setuptools -sqlalchemy tenacity @@ -39,3 +26,5 @@ tenacity tiktoken tqdm xxhash + +# Extra libraries are installed when needed using pipmaster
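
Note on the dependency changes above: heavyweight backend drivers are removed from requirements.txt and installed on demand through pipmaster at import time. A minimal sketch of that guard pattern, using the pyvis case from examples/graph_visual_with_html.py (only is_installed/install and the package names shown in the diff are assumed):

# On-demand install guard (sketch): check the environment, install if missing, then import.
import pipmaster as pm

if not pm.is_installed("pyvis"):
    pm.install("pyvis")

from pyvis.network import Network  # safe to import once the guard has run

This keeps the base requirements small while each storage implementation under lightrag/kg (neo4j, pymilvus, pymongo, oracledb, asyncpg, redis, pymysql/sqlalchemy, nano-vectordb) pulls in its own driver only when that backend is actually selected.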