Merge pull request #660 from ParisNeo/main

LightRAG Enhancements PR
This commit is contained in:
zrguo
2025-01-27 23:17:03 +08:00
committed by GitHub
30 changed files with 1165 additions and 846 deletions

View File

@@ -1,4 +1,8 @@
import networkx as nx
import pipmaster as pm
if not pm.is_installed("pyvis"):
pm.install("pyvis")
from pyvis.network import Network
import random

View File

@@ -1798,12 +1798,13 @@ def create_app(args):
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
files = doc_manager.scan_directory()
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": doc_manager.indexed_files,
"indexed_files_count": len(doc_manager.indexed_files),
"indexed_files": files,
"indexed_files_count": len(files),
"configuration": {
# LLM configuration binding/host address (if applicable)/model (if applicable)
"llm_binding": args.llm_binding,

View File

@@ -1,10 +1,7 @@
ascii_colors
fastapi
nano_vectordb
nest_asyncio
numpy
ollama
openai
pipmaster
python-dotenv
python-multipart
@@ -12,5 +9,4 @@ tenacity
tiktoken
torch
tqdm
transformers
uvicorn

View File

@@ -98,358 +98,7 @@
</div>
</div>
<script>
// State management
const state = {
apiKey: localStorage.getItem('apiKey') || '',
files: [],
indexedFiles: [],
currentPage: 'file-manager'
};
<script src="/js/lightrag_api.js"></script>
// Utility functions
const showToast = (message, duration = 3000) => {
const toast = document.getElementById('toast');
toast.querySelector('div').textContent = message;
toast.classList.remove('hidden');
setTimeout(() => toast.classList.add('hidden'), duration);
};
const fetchWithAuth = async (url, options = {}) => {
const headers = {
...(options.headers || {}),
...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
};
return fetch(url, { ...options, headers });
};
// Page renderers
const pages = {
'file-manager': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">File Manager</h2>
<div class="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center hover:border-gray-400 transition-colors">
<input type="file" id="fileInput" multiple accept=".txt,.md,.doc,.docx,.pdf,.pptx" class="hidden">
<label for="fileInput" class="cursor-pointer">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"/>
</svg>
<p class="mt-2 text-gray-600">Drag files here or click to select</p>
<p class="text-sm text-gray-500">Supported formats: TXT, MD, DOC, PDF, PPTX</p>
</label>
</div>
<div id="fileList" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Selected Files</h3>
<div class="space-y-2"></div>
</div>
<div id="uploadProgress" class="hidden mt-4">
<div class="w-full bg-gray-200 rounded-full h-2.5">
<div class="bg-blue-600 h-2.5 rounded-full" style="width: 0%"></div>
</div>
<p class="text-sm text-gray-600 mt-2"><span id="uploadStatus">0</span> files processed</p>
</div>
<button id="uploadBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Upload & Index Files
</button>
<div id="indexedFiles" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Indexed Files</h3>
<div class="space-y-2"></div>
</div>
</div>
`,
'query': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Query Database</h2>
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">Query Mode</label>
<select id="queryMode" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
<option value="hybrid">Hybrid</option>
<option value="local">Local</option>
<option value="global">Global</option>
<option value="naive">Naive</option>
</select>
</div>
<div>
<label class="block text-sm font-medium text-gray-700">Query</label>
<textarea id="queryInput" rows="4" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500"></textarea>
</div>
<button id="queryBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Send Query
</button>
<div id="queryResult" class="mt-4 p-4 bg-white rounded-lg shadow"></div>
</div>
</div>
`,
'knowledge-graph': () => `
<div class="flex items-center justify-center h-full">
<div class="text-center">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10"/>
</svg>
<h3 class="mt-2 text-sm font-medium text-gray-900">Under Construction</h3>
<p class="mt-1 text-sm text-gray-500">Knowledge graph visualization will be available in a future update.</p>
</div>
</div>
`,
'status': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">System Status</h2>
<div id="statusContent" class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">System Health</h3>
<div id="healthStatus"></div>
</div>
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">Configuration</h3>
<div id="configStatus"></div>
</div>
</div>
</div>
`,
'settings': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Settings</h2>
<div class="max-w-xl">
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">API Key</label>
<input type="password" id="apiKeyInput" value="${state.apiKey}"
class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
</div>
<button id="saveSettings" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Save Settings
</button>
</div>
</div>
</div>
`
};
// Page handlers
const handlers = {
'file-manager': () => {
const fileInput = document.getElementById('fileInput');
const dropZone = fileInput.parentElement.parentElement;
const fileList = document.querySelector('#fileList div');
const indexedFiles = document.querySelector('#indexedFiles div');
const uploadBtn = document.getElementById('uploadBtn');
const updateFileList = () => {
fileList.innerHTML = state.files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file.name}</span>
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
</svg>
</button>
</div>
`).join('');
};
const updateIndexedFiles = async () => {
const response = await fetchWithAuth('/health');
const data = await response.json();
indexedFiles.innerHTML = data.indexed_files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file}</span>
</div>
`).join('');
};
dropZone.addEventListener('dragover', (e) => {
e.preventDefault();
dropZone.classList.add('border-blue-500');
});
dropZone.addEventListener('dragleave', () => {
dropZone.classList.remove('border-blue-500');
});
dropZone.addEventListener('drop', (e) => {
e.preventDefault();
dropZone.classList.remove('border-blue-500');
const files = Array.from(e.dataTransfer.files);
state.files.push(...files);
updateFileList();
});
fileInput.addEventListener('change', () => {
state.files.push(...Array.from(fileInput.files));
updateFileList();
});
uploadBtn.addEventListener('click', async () => {
if (state.files.length === 0) {
showToast('Please select files to upload');
return;
}
let apiKey = localStorage.getItem('apiKey') || '';
const progress = document.getElementById('uploadProgress');
const progressBar = progress.querySelector('div');
const statusText = document.getElementById('uploadStatus');
progress.classList.remove('hidden');
for (let i = 0; i < state.files.length; i++) {
const formData = new FormData();
formData.append('file', state.files[i]);
try {
await fetch('/documents/upload', {
method: 'POST',
headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {},
body: formData
});
const percentage = ((i + 1) / state.files.length) * 100;
progressBar.style.width = `${percentage}%`;
statusText.textContent = i + 1;
} catch (error) {
console.error('Upload error:', error);
}
}
progress.classList.add('hidden');
});
updateIndexedFiles();
},
'query': () => {
const queryBtn = document.getElementById('queryBtn');
const queryInput = document.getElementById('queryInput');
const queryMode = document.getElementById('queryMode');
const queryResult = document.getElementById('queryResult');
queryBtn.addEventListener('click', async () => {
const query = queryInput.value.trim();
if (!query) {
showToast('Please enter a query');
return;
}
queryBtn.disabled = true;
queryBtn.innerHTML = `
<svg class="animate-spin h-5 w-5 mr-3" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"/>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"/>
</svg>
Processing...
`;
try {
const response = await fetchWithAuth('/query', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
query,
mode: queryMode.value,
stream: false,
only_need_context: false
})
});
const data = await response.json();
queryResult.innerHTML = marked.parse(data.response);
} catch (error) {
showToast('Error processing query');
} finally {
queryBtn.disabled = false;
queryBtn.textContent = 'Send Query';
}
});
},
'status': async () => {
const healthStatus = document.getElementById('healthStatus');
const configStatus = document.getElementById('configStatus');
try {
const response = await fetchWithAuth('/health');
const data = await response.json();
healthStatus.innerHTML = `
<div class="space-y-2">
<div class="flex items-center">
<div class="w-3 h-3 rounded-full ${data.status === 'healthy' ? 'bg-green-500' : 'bg-red-500'} mr-2"></div>
<span class="font-medium">${data.status}</span>
</div>
<div>
<p class="text-sm text-gray-600">Working Directory: ${data.working_directory}</p>
<p class="text-sm text-gray-600">Input Directory: ${data.input_directory}</p>
<p class="text-sm text-gray-600">Indexed Files: ${data.indexed_files_count}</p>
</div>
</div>
`;
configStatus.innerHTML = Object.entries(data.configuration)
.map(([key, value]) => `
<div class="mb-2">
<span class="text-sm font-medium text-gray-700">${key}:</span>
<span class="text-sm text-gray-600 ml-2">${value}</span>
</div>
`).join('');
} catch (error) {
showToast('Error fetching status');
}
},
'settings': () => {
const saveBtn = document.getElementById('saveSettings');
const apiKeyInput = document.getElementById('apiKeyInput');
saveBtn.addEventListener('click', () => {
state.apiKey = apiKeyInput.value;
localStorage.setItem('apiKey', state.apiKey);
showToast('Settings saved successfully');
});
}
};
// Navigation handling
document.querySelectorAll('.nav-item').forEach(item => {
item.addEventListener('click', (e) => {
e.preventDefault();
const page = item.dataset.page;
document.getElementById('content').innerHTML = pages[page]();
if (handlers[page]) handlers[page]();
state.currentPage = page;
});
});
// Initialize with file manager
document.getElementById('content').innerHTML = pages['file-manager']();
handlers['file-manager']();
// Global functions
window.removeFile = (fileName) => {
state.files = state.files.filter(file => file.name !== fileName);
document.querySelector('#fileList div').innerHTML = state.files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file.name}</span>
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
</svg>
</button>
</div>
`).join('');
};
</script>
</body>
</html>

View File

@@ -0,0 +1,375 @@
// Shared client-side state for the web UI.
// apiKey is restored from localStorage so it survives page reloads.
const storedApiKey = localStorage.getItem('apiKey');
const state = {
    apiKey: storedApiKey || '',
    files: [],
    indexedFiles: [],
    currentPage: 'file-manager'
};
// Utility functions
/**
 * Show a transient notification in the #toast element.
 *
 * @param {string} message  Text placed in the toast's first <div>.
 * @param {number} duration Milliseconds before the toast is hidden again.
 */
const showToast = (message, duration = 3000) => {
    const toast = document.getElementById('toast');
    toast.querySelector('div').textContent = message;
    toast.classList.remove('hidden');
    // Cancel the hide-timer of any earlier toast; otherwise that stale timer
    // would hide this newer message before its own duration has elapsed.
    clearTimeout(showToast.hideTimer);
    showToast.hideTimer = setTimeout(() => toast.classList.add('hidden'), duration);
};
/**
 * fetch() wrapper that adds an `Authorization: Bearer <key>` header whenever
 * state.apiKey is non-empty. Caller-supplied headers are kept; the
 * Authorization entry wins if both define it.
 */
const fetchWithAuth = async (url, options = {}) => {
    const authHeader = state.apiKey
        ? { 'Authorization': `Bearer ${state.apiKey}` }
        : {};
    const headers = Object.assign({}, options.headers || {}, authHeader);
    return fetch(url, Object.assign({}, options, { headers }));
};
// Page renderers
const pages = {
'file-manager': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">File Manager</h2>
<div class="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center hover:border-gray-400 transition-colors">
<input type="file" id="fileInput" multiple accept=".txt,.md,.doc,.docx,.pdf,.pptx" class="hidden">
<label for="fileInput" class="cursor-pointer">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"/>
</svg>
<p class="mt-2 text-gray-600">Drag files here or click to select</p>
<p class="text-sm text-gray-500">Supported formats: TXT, MD, DOC, PDF, PPTX</p>
</label>
</div>
<div id="fileList" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Selected Files</h3>
<div class="space-y-2"></div>
</div>
<div id="uploadProgress" class="hidden mt-4">
<div class="w-full bg-gray-200 rounded-full h-2.5">
<div class="bg-blue-600 h-2.5 rounded-full" style="width: 0%"></div>
</div>
<p class="text-sm text-gray-600 mt-2"><span id="uploadStatus">0</span> files processed</p>
</div>
<button id="uploadBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Upload & Index Files
</button>
<div id="indexedFiles" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Indexed Files</h3>
<div class="space-y-2"></div>
</div>
<button id="rescanBtn" class="flex items-center bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="20" height="20" fill="currentColor" class="mr-2">
<path d="M12 4a8 8 0 1 1-8 8H2.5a9.5 9.5 0 1 0 2.8-6.7L2 3v6h6L5.7 6.7A7.96 7.96 0 0 1 12 4z"/>
</svg>
Rescan Files
</button>
</div>
`,
'query': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Query Database</h2>
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">Query Mode</label>
<select id="queryMode" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
<option value="hybrid">Hybrid</option>
<option value="local">Local</option>
<option value="global">Global</option>
<option value="naive">Naive</option>
</select>
</div>
<div>
<label class="block text-sm font-medium text-gray-700">Query</label>
<textarea id="queryInput" rows="4" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500"></textarea>
</div>
<button id="queryBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Send Query
</button>
<div id="queryResult" class="mt-4 p-4 bg-white rounded-lg shadow"></div>
</div>
</div>
`,
'knowledge-graph': () => `
<div class="flex items-center justify-center h-full">
<div class="text-center">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10"/>
</svg>
<h3 class="mt-2 text-sm font-medium text-gray-900">Under Construction</h3>
<p class="mt-1 text-sm text-gray-500">Knowledge graph visualization will be available in a future update.</p>
</div>
</div>
`,
'status': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">System Status</h2>
<div id="statusContent" class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">System Health</h3>
<div id="healthStatus"></div>
</div>
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">Configuration</h3>
<div id="configStatus"></div>
</div>
</div>
</div>
`,
'settings': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Settings</h2>
<div class="max-w-xl">
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">API Key</label>
<input type="password" id="apiKeyInput" value="${state.apiKey}"
class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
</div>
<button id="saveSettings" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Save Settings
</button>
</div>
</div>
</div>
`
};
// Page handlers
const handlers = {
'file-manager': () => {
const fileInput = document.getElementById('fileInput');
const dropZone = fileInput.parentElement.parentElement;
const fileList = document.querySelector('#fileList div');
const indexedFiles = document.querySelector('#indexedFiles div');
const uploadBtn = document.getElementById('uploadBtn');
// Re-render the "Selected Files" list from state.files.
// Rows are built as DOM nodes: the file name is set through textContent and the
// remove button uses a real click listener, so names containing quotes or
// markup can neither break the handler nor be parsed as HTML.
const updateFileList = () => {
    const trashIcon =
        '<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">' +
        '<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>' +
        '</svg>';
    fileList.innerHTML = '';
    state.files.forEach(file => {
        const row = document.createElement('div');
        row.className = 'flex items-center justify-between bg-white p-3 rounded-lg shadow-sm';

        const label = document.createElement('span');
        label.textContent = file.name;

        const removeBtn = document.createElement('button');
        removeBtn.className = 'text-red-600 hover:text-red-700';
        removeBtn.innerHTML = trashIcon; // static markup only
        removeBtn.addEventListener('click', () => window.removeFile(file.name));

        row.appendChild(label);
        row.appendChild(removeBtn);
        fileList.appendChild(row);
    });
};
// Fetch /health and list its `indexed_files` entries in the "Indexed Files" panel.
// Failures (network error or non-2xx status) are logged and reported via a toast
// instead of surfacing as an unhandled promise rejection.
const updateIndexedFiles = async () => {
    try {
        const response = await fetchWithAuth('/health');
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        const data = await response.json();
        indexedFiles.innerHTML = '';
        (data.indexed_files || []).forEach(file => {
            const row = document.createElement('div');
            row.className = 'flex items-center justify-between bg-white p-3 rounded-lg shadow-sm';
            const label = document.createElement('span');
            // textContent keeps server-provided names from being parsed as HTML.
            label.textContent = file;
            row.appendChild(label);
            indexedFiles.appendChild(row);
        });
    } catch (error) {
        console.error('Failed to load indexed files:', error);
        showToast('Error fetching indexed files');
    }
};
// Drag-and-drop wiring for the dashed drop zone.
// preventDefault() on dragover is what allows the element to receive a drop;
// the border class is toggled as visual feedback while a drag is over it.
dropZone.addEventListener('dragover', (e) => {
e.preventDefault();
dropZone.classList.add('border-blue-500');
});
// Drag left the zone without dropping: remove the highlight.
dropZone.addEventListener('dragleave', () => {
dropZone.classList.remove('border-blue-500');
});
// Files dropped: append them to state.files and re-render the selection list.
dropZone.addEventListener('drop', (e) => {
e.preventDefault();
dropZone.classList.remove('border-blue-500');
const files = Array.from(e.dataTransfer.files);
state.files.push(...files);
updateFileList();
});
// Files chosen through the hidden <input type="file">: same handling as a drop.
fileInput.addEventListener('change', () => {
state.files.push(...Array.from(fileInput.files));
updateFileList();
});
// Upload every selected file to /documents/upload, one request at a time,
// advancing the progress bar after each request settles.
uploadBtn.addEventListener('click', async () => {
    if (state.files.length === 0) {
        showToast('Please select files to upload');
        return;
    }
    // Snapshot the selection so removing a file mid-upload cannot shift indices.
    const files = [...state.files];
    const total = files.length;
    const progress = document.getElementById('uploadProgress');
    // Select the inner fill element. querySelector('div') returns the outer
    // grey track first, so the visible bar would never advance.
    const progressBar = progress.querySelector('.bg-blue-600');
    const statusText = document.getElementById('uploadStatus');
    let failed = 0;
    progress.classList.remove('hidden');
    for (let i = 0; i < total; i++) {
        const formData = new FormData();
        formData.append('file', files[i]);
        try {
            const response = await fetchWithAuth('/documents/upload', {
                method: 'POST',
                body: formData
            });
            // fetch() only rejects on network failure; HTTP errors must be checked.
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
        } catch (error) {
            failed += 1;
            console.error('Upload error:', error);
        }
        const percentage = ((i + 1) / total) * 100;
        progressBar.style.width = `${percentage}%`;
        statusText.textContent = `${i + 1}/${total}`;
    }
    progress.classList.add('hidden');
    if (failed > 0) {
        showToast(`${failed} of ${total} uploads failed`);
    }
    // Refresh the "Indexed Files" panel so newly uploaded files show up.
    updateIndexedFiles();
});
// Ask the server to rescan its input directory, then refresh the indexed list.
// The button is looked up explicitly rather than relying on the browser's
// implicit `window.rescanBtn` global created from the element id.
document.getElementById('rescanBtn').addEventListener('click', async () => {
    const progress = document.getElementById('uploadProgress');
    const statusText = document.getElementById('uploadStatus');
    progress.classList.remove('hidden');
    try {
        // NOTE(review): the HTTP method is kept as in the original; confirm it
        // matches the server's /documents/scan route.
        const response = await fetchWithAuth('/documents/scan', {
            method: 'GET',
        });
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        // A fetch Response has no `.data` property; the payload must be parsed.
        const scanOutput = await response.json();
        // Response schema is not visible from this file: show `data` when the
        // payload has it, otherwise the raw JSON.
        statusText.textContent = (scanOutput && scanOutput.data !== undefined)
            ? scanOutput.data
            : JSON.stringify(scanOutput);
        showToast('Rescan complete');
        updateIndexedFiles();
    } catch (error) {
        console.error('Rescan error:', error);
        showToast('Error rescanning files');
    } finally {
        progress.classList.add('hidden');
    }
});
updateIndexedFiles();
},
'query': () => {
const queryBtn = document.getElementById('queryBtn');
const queryInput = document.getElementById('queryInput');
const queryMode = document.getElementById('queryMode');
const queryResult = document.getElementById('queryResult');
let apiKey = localStorage.getItem('apiKey') || '';
queryBtn.addEventListener('click', async () => {
const query = queryInput.value.trim();
if (!query) {
showToast('Please enter a query');
return;
}
queryBtn.disabled = true;
queryBtn.innerHTML = `
<svg class="animate-spin h-5 w-5 mr-3" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"/>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"/>
</svg>
Processing...
`;
try {
const response = await fetchWithAuth('/query', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
query,
mode: queryMode.value,
stream: false,
only_need_context: false
})
});
const data = await response.json();
queryResult.innerHTML = marked.parse(data.response);
} catch (error) {
showToast('Error processing query');
} finally {
queryBtn.disabled = false;
queryBtn.textContent = 'Send Query';
}
});
},
// Status page handler: fetches /health and fills the two status cards.
// #healthStatus shows the status flag, the two directories and the indexed-file count;
// #configStatus lists every key/value pair of `data.configuration`.
'status': async () => {
const healthStatus = document.getElementById('healthStatus');
const configStatus = document.getElementById('configStatus');
try {
const response = await fetchWithAuth('/health');
const data = await response.json();
// Green dot when status is exactly 'healthy', red otherwise.
healthStatus.innerHTML = `
<div class="space-y-2">
<div class="flex items-center">
<div class="w-3 h-3 rounded-full ${data.status === 'healthy' ? 'bg-green-500' : 'bg-red-500'} mr-2"></div>
<span class="font-medium">${data.status}</span>
</div>
<div>
<p class="text-sm text-gray-600">Working Directory: ${data.working_directory}</p>
<p class="text-sm text-gray-600">Input Directory: ${data.input_directory}</p>
<p class="text-sm text-gray-600">Indexed Files: ${data.indexed_files_count}</p>
</div>
</div>
`;
// NOTE(review): values are interpolated into innerHTML without escaping;
// this assumes the /health payload is trusted.
configStatus.innerHTML = Object.entries(data.configuration)
.map(([key, value]) => `
<div class="mb-2">
<span class="text-sm font-medium text-gray-700">${key}:</span>
<span class="text-sm text-gray-600 ml-2">${value}</span>
</div>
`).join('');
} catch (error) {
// Any fetch, JSON-parse or rendering error ends up here.
showToast('Error fetching status');
}
},
'settings': () => {
const saveBtn = document.getElementById('saveSettings');
const apiKeyInput = document.getElementById('apiKeyInput');
saveBtn.addEventListener('click', () => {
state.apiKey = apiKeyInput.value;
localStorage.setItem('apiKey', state.apiKey);
showToast('Settings saved successfully');
});
}
};
// Navigation: clicking a .nav-item renders the page named by its data-page
// attribute into #content and then runs that page's handler, if it has one.
for (const navItem of document.querySelectorAll('.nav-item')) {
    navItem.addEventListener('click', (event) => {
        event.preventDefault();
        const { page } = navItem.dataset;
        const contentRoot = document.getElementById('content');
        contentRoot.innerHTML = pages[page]();
        const attachHandlers = handlers[page];
        if (attachHandlers) attachHandlers();
        state.currentPage = page;
    });
}
// Start on the file manager page.
document.getElementById('content').innerHTML = pages['file-manager']();
handlers['file-manager']();
// Global functions
/**
 * Remove every selected file whose name equals `fileName` from state.files
 * and re-render the "Selected Files" list.
 *
 * Rows are built as DOM nodes (textContent + click listener) rather than an
 * HTML string with an inline onclick, so file names containing quotes or
 * markup cannot break the remove button or be parsed as HTML.
 */
window.removeFile = (fileName) => {
    state.files = state.files.filter(file => file.name !== fileName);
    const listContainer = document.querySelector('#fileList div');
    const trashIcon =
        '<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">' +
        '<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>' +
        '</svg>';
    listContainer.innerHTML = '';
    state.files.forEach(file => {
        const row = document.createElement('div');
        row.className = 'flex items-center justify-between bg-white p-3 rounded-lg shadow-sm';

        const label = document.createElement('span');
        label.textContent = file.name;

        const removeBtn = document.createElement('button');
        removeBtn.className = 'text-red-600 hover:text-red-700';
        removeBtn.innerHTML = trashIcon; // static markup only
        removeBtn.addEventListener('click', () => window.removeFile(file.name));

        row.appendChild(label);
        row.appendChild(removeBtn);
        listContainer.appendChild(row);
    });
};

View File

@@ -6,6 +6,14 @@ import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import psycopg
from psycopg.rows import namedtuple_row

137
lightrag/kg/json_kv_impl.py Normal file
View File

@@ -0,0 +1,137 @@
"""
JsonKVStorage Module
====================
This module provides a key-value storage backend that keeps its records in an in-memory dictionary and persists them to a JSON file (`kv_store_<namespace>.json`) in the configured working directory.
The `JsonKVStorage` class extends the `BaseKVStorage` class from the LightRAG library, providing methods to look up, filter, insert, and delete records.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- LightRAG (`lightrag.base`, `lightrag.utils`)
- Python standard library (`asyncio`, `os`, `dataclasses`)
Features:
- Load records from a JSON file at start-up and write them back on `index_done_callback`
- Look up records by a single ID or a list of IDs, optionally restricted to selected fields
- Report which of a list of keys are not yet stored
- Insert new records (keys already present are left unchanged)
- Filter records with a predicate and delete records by ID
Usage:
from lightrag.kg.json_kv_impl import JsonKVStorage
"""
import asyncio
import os
from dataclasses import dataclass
from lightrag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
BaseKVStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage held in memory and persisted to a JSON file.

    The file is ``kv_store_<namespace>.json`` inside
    ``global_config["working_dir"]``. Records are loaded once in
    ``__post_init__`` and written back by ``index_done_callback``.
    """

    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        # load_json may return a falsy value (e.g. when nothing was stored yet);
        # fall back to an empty store in that case.
        self._data = load_json(self._file_name) or {}
        # Guards the iteration in filter() and the mutation in delete().
        self._lock = asyncio.Lock()
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        """Return all stored keys."""
        return list(self._data)

    async def index_done_callback(self):
        """Write the current in-memory data to the JSON file."""
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """Return the record stored under ``id``, or ``None`` if absent."""
        return self._data.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Return the records for ``ids``, in the same order.

        Args:
            ids: Iterable of keys to look up.
            fields: Optional collection of field names; when given, each
                returned record contains only those fields.

        Returns:
            A list with one entry per requested key. Without ``fields`` a
            missing key yields ``None``; with ``fields`` a missing *or falsy*
            record yields ``None``.
        """
        if fields is None:
            return [self._data.get(key) for key in ids]
        return [
            (
                {k: v for k, v in self._data[key].items() if k in fields}
                if self._data.get(key)
                else None
            )
            for key in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` that is not yet stored."""
        return {key for key in data if key not in self._data}

    async def upsert(self, data: dict[str, dict]):
        """Insert the entries of ``data`` whose keys are not stored yet.

        Keys that already exist are left untouched (no update is performed).

        Returns:
            The dictionary of entries that were actually inserted.
        """
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    async def drop(self):
        """Discard all in-memory records (the file is not rewritten here)."""
        self._data = {}

    async def filter(self, filter_func):
        """Filter key-value pairs based on a filter function

        Args:
            filter_func: The filter function, which takes a value as an argument and returns a boolean value

        Returns:
            Dict: Key-value pairs that meet the condition
        """
        async with self._lock:
            return {
                key: value for key, value in self._data.items() if filter_func(value)
            }

    async def delete(self, ids: list[str]):
        """Delete data with specified IDs

        IDs that are not stored are ignored. The store is written to disk
        afterwards and the number of records actually removed is logged.

        Args:
            ids: List of IDs to delete
        """
        async with self._lock:
            deleted = 0
            for key in ids:
                if key in self._data:
                    del self._data[key]
                    deleted += 1
            await self.index_done_callback()
            # Log what was really removed, not how many IDs were requested.
            logger.info(f"Successfully deleted {deleted} items from {self.namespace}")

View File

@@ -0,0 +1,128 @@
"""
JsonDocStatus Storage Module
=======================
This module provides a storage backend for document processing status. Records are kept in an in-memory dictionary and persisted to a JSON file (`kv_store_<namespace>.json`) in the configured working directory.
The `JsonDocStatusStorage` class extends the `DocStatusStorage` class from the LightRAG library, providing methods to store, query, count, and delete per-document status records.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- LightRAG (`lightrag.base`, `lightrag.utils`)
- Python standard library (`os`, `dataclasses`, `typing`)
Features:
- Load status records from a JSON file at start-up and write them back after every change
- Report which document IDs still need processing
- Count documents per status
- List failed and pending documents
- Insert, update, fetch, and delete status records by document ID
Usage:
Instantiate `JsonDocStatusStorage` wherever a `DocStatusStorage` implementation is required.
"""
import os
from dataclasses import dataclass
from typing import Union, Dict
from lightrag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
DocStatus,
DocProcessingStatus,
DocStatusStorage,
)
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
    """Document status storage kept in memory and persisted as JSON."""

    def __post_init__(self):
        # The status file lives next to the other kv_store_* files.
        self._file_name = os.path.join(
            self.global_config["working_dir"], f"kv_store_{self.namespace}.json"
        )
        loaded = load_json(self._file_name)
        self._data = loaded or {}
        logger.info(f"Loaded document status storage with {len(self._data)} records")

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys that are unknown or whose status is not PROCESSED."""
        to_process = set()
        for key in data:
            if key not in self._data:
                to_process.add(key)
            elif self._data[key]["status"] != DocStatus.PROCESSED:
                to_process.add(key)
        return to_process

    async def get_status_counts(self) -> Dict[str, int]:
        """Count the stored documents per status; every DocStatus member starts at 0."""
        counts = dict.fromkeys(DocStatus, 0)
        for record in self._data.values():
            counts[record["status"]] += 1
        return counts

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return every record whose status is FAILED, keyed by document ID."""
        failed = {}
        for doc_id, record in self._data.items():
            if record["status"] == DocStatus.FAILED:
                failed[doc_id] = record
        return failed

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return every record whose status is PENDING, keyed by document ID."""
        pending = {}
        for doc_id, record in self._data.items():
            if record["status"] == DocStatus.PENDING:
                pending[doc_id] = record
        return pending

    async def index_done_callback(self):
        """Persist the in-memory records to the JSON file."""
        write_json(self._data, self._file_name)

    async def upsert(self, data: dict[str, dict]):
        """Insert or overwrite status records, then persist.

        Args:
            data: Mapping of document ID to its status data

        Returns:
            The same ``data`` mapping that was passed in.
        """
        self._data.update(data)
        await self.index_done_callback()
        return data

    async def get_by_id(self, id: str):
        """Return the record stored under ``id``, or ``None``."""
        return self._data.get(id)

    async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
        """Return the status record for ``doc_id``, or ``None``."""
        return self._data.get(doc_id)

    async def delete(self, doc_ids: list[str]):
        """Remove the given document IDs (unknown IDs are ignored), then persist."""
        for doc_id in doc_ids:
            if doc_id in self._data:
                del self._data[doc_id]
        await self.index_done_callback()

View File

@@ -6,6 +6,9 @@ import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
from pymilvus import MilvusClient

View File

@@ -1,6 +1,10 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
if not pm.is_installed("pymongo"):
pm.install("pymongo")
from pymongo import MongoClient
from typing import Union
from lightrag.utils import logger

View File

@@ -0,0 +1,206 @@
"""
NanoVectorDB Storage Module
=======================
This module provides a vector storage interface backed by NanoVectorDB, a vector database that persists its data to a JSON file.
The `NanoVectorDBStorage` class extends the `BaseVectorStorage` class from the LightRAG library, providing methods to upsert, query, and delete embedding vectors.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.kg.nano_vector_db_impl import NanoVectorDBStorage
"""
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
import time
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import (
BaseVectorStorage,
)
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a NanoVectorDB client and its JSON storage file."""

    # Passed to the client as ``better_than_threshold`` on every query; can be
    # overridden via ``global_config["cosine_better_than_threshold"]``.
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed the ``content`` of each record and upsert it into the client.

        Args:
            data: Mapping of vector ID to a record dict. Each record must have
                a ``"content"`` entry; only keys listed in ``self.meta_fields``
                are stored alongside the vector.

        Returns:
            ``[]`` for empty input, the client's upsert result on success, or
            ``None`` when the number of embeddings does not match the number of
            records (the mismatch is logged and nothing is written).
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        current_time = time.time()
        list_data = [
            {
                "__id__": k,
                "__created_at__": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Create the bar before the closure that updates it, and always close
        # it so it is not left open when an embedding call fails.
        pbar = tqdm_async(
            total=len(batches), desc="Generating embeddings", unit="batch"
        )

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        try:
            embeddings_list = await asyncio.gather(
                *(wrapped_task(batch) for batch in batches)
            )
        finally:
            pbar.close()

        embeddings = np.concatenate(embeddings_list)
        if len(embeddings) != len(list_data):
            # sometimes the embedding is not returned correctly. just log it.
            logger.error(
                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
            )
            return None

        for record, vector in zip(list_data, embeddings):
            record["__vector__"] = vector
        return self._client.upsert(datas=list_data)

    async def query(self, query: str, top_k=5):
        """Embed ``query`` and return the client's matches.

        Each returned dict is the client's record extended with ``"id"``,
        ``"distance"`` (the client's ``__metrics__`` value) and ``"created_at"``.
        """
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        return [
            {
                **dp,
                "id": dp["__id__"],
                "distance": dp["__metrics__"],
                "created_at": dp.get("__created_at__"),
            }
            for dp in results
        ]

    @property
    def client_storage(self):
        """Return the client's name-mangled private ``__storage`` attribute."""
        return getattr(self._client, "_NanoVectorDB__storage")

    async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs

        Failures are logged rather than raised.

        Args:
            ids: List of vector IDs to be deleted
        """
        try:
            self._client.delete(ids)
            logger.info(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str):
        """Delete the vector whose ID is the ``ent-`` hash of ``entity_name``.

        Failures are logged rather than raised.
        """
        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            # Check if the entity exists
            if self._client.get([entity_id]):
                await self.delete([entity_id])
                logger.debug(f"Successfully deleted entity {entity_name}")
            else:
                logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str):
        """Delete every stored record whose ``src_id`` or ``tgt_id`` is ``entity_name``.

        Failures are logged rather than raised.
        """
        try:
            # Use .get so a record without src_id/tgt_id is skipped instead of
            # raising KeyError and aborting the whole deletion.
            relations = [
                dp
                for dp in self.client_storage["data"]
                if dp.get("src_id") == entity_name or dp.get("tgt_id") == entity_name
            ]
            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
            ids_to_delete = [relation["__id__"] for relation in relations]

            if ids_to_delete:
                await self.delete(ids_to_delete)
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def index_done_callback(self):
        """Ask the client to save its data."""
        self._client.save()

View File

@@ -3,6 +3,9 @@ import inspect
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,

View File

@@ -0,0 +1,227 @@
"""
NetworkX Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.kg.networkx_impl import NetworkXStorage
"""
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from lightrag.utils import (
logger,
)
from lightrag.base import (
BaseGraphStorage,
)
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """Graph storage kept in memory as a NetworkX graph and persisted as GraphML."""

    @staticmethod
    def load_nx_graph(file_name) -> Union[nx.Graph, None]:
        """Read a GraphML file, returning ``None`` when the file does not exist."""
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Write ``graph`` to ``file_name`` in GraphML format."""
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        # Normalise node names: upper-case, strip whitespace, unescape HTML entities.
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = sorted(graph.nodes(data=True), key=lambda x: x[0])
        fixed_graph.add_nodes_from(sorted_nodes)

        edges = list(graph.edges(data=True))
        if not graph.is_directed():
            # Orient every undirected edge so the smaller endpoint comes first.
            edges = [
                (target, source, edge_data)
                if source > target
                else (source, target, edge_data)
                for source, target, edge_data in edges
            ]

        # Sort on the textual "source -> target" key.
        edges.sort(key=lambda x: f"{x[0]} -> {x[1]}")
        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
            # Compare against None explicitly: a graph with zero nodes is
            # falsy, so ``preloaded_graph or nx.Graph()`` would silently
            # discard a loaded-but-empty graph.
            self._graph = preloaded_graph
        else:
            self._graph = nx.Graph()
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        """Persist the current graph to its GraphML file."""
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        """Return whether ``node_id`` is in the graph."""
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether the graph has an edge between the two nodes."""
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the attribute dict of ``node_id``, or ``None`` when absent."""
        return self._graph.nodes.get(node_id)

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of ``node_id``."""
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of the two endpoint nodes."""
        return self._graph.degree(src_id) + self._graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the attribute dict of the edge, or ``None`` when absent."""
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        """Return the node's edges as a list, or ``None`` when the node is absent."""
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Add ``node_id`` or update its attributes with ``node_data``."""
        self._graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Add the edge or update its attributes with ``edge_data``."""
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info(f"Node {node_id} deleted from the graph.")
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Run the named node-embedding algorithm.

        Raises:
            ValueError: If ``algorithm`` is not a registered algorithm name.
        """
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    # @TODO: NOT USED
    async def _node2vec_embed(self):
        """Embed the graph with graspologic node2vec using ``global_config["node2vec_params"]``.

        Returns the embeddings together with the ``"id"`` attribute of each
        embedded node.
        """
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

    def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node IDs to be deleted
        """
        for node in nodes:
            if self._graph.has_node(node):
                self._graph.remove_node(node)

    def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
        for source, target in edges:
            if self._graph.has_edge(source, target):
                self._graph.remove_edge(source, target)

View File

@@ -6,6 +6,11 @@ from dataclasses import dataclass
from typing import Union
import numpy as np
import array
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
from ..utils import logger
from ..base import (

View File

@@ -6,6 +6,11 @@ import time
from dataclasses import dataclass
from typing import Union, List, Dict, Set, Any, Tuple
import numpy as np
import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
import sys
from tqdm.asyncio import tqdm as tqdm_async

View File

@@ -1,8 +1,15 @@
import asyncio
import asyncpg
import sys
import os
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage

View File

@@ -1,6 +1,9 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
if not pm.is_installed("redis"):
pm.install("redis")
# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis

View File

@@ -4,13 +4,18 @@ from dataclasses import dataclass
from typing import Union
import numpy as np
import pipmaster as pm
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from sqlalchemy import create_engine, text
from tqdm import tqdm
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger
class TiDB(object):
def __init__(self, config, **kwargs):
self.host = config.get("host", None)

View File

@@ -38,10 +38,10 @@ from .base import (
from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"JsonKVStorage": ".storage",
"NanoVectorDBStorage": ".storage",
"NetworkXStorage": ".storage",
"JsonDocStatusStorage": ".storage",
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.jsondocstatus_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",

View File

@@ -1,460 +0,0 @@
import asyncio
import html
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
from typing import Any, Union, cast, Dict
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
import time
from .utils import (
logger,
load_json,
write_json,
compute_mdhash_id,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocStatus,
DocProcessingStatus,
DocStatusStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage persisted as one JSON file in the working directory."""

    def __post_init__(self):
        self._file_name = os.path.join(
            self.global_config["working_dir"], f"kv_store_{self.namespace}.json"
        )
        self._data = load_json(self._file_name) or {}
        # Guards the filter and delete operations below.
        self._lock = asyncio.Lock()
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        """Return all stored keys as a list."""
        return list(self._data)

    async def index_done_callback(self):
        """Write the in-memory data to the backing JSON file."""
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """Return the value stored under ``id``, or ``None`` when absent."""
        return self._data.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Return the values for ``ids`` in order, ``None`` for missing entries.

        When ``fields`` is given, each found value is reduced to the keys
        contained in ``fields``.
        """
        records = [self._data.get(key) for key in ids]
        if fields is None:
            return records
        return [
            {k: v for k, v in record.items() if k in fields} if record else None
            for record in records
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` that is not stored yet."""
        return {key for key in data if key not in self._data}

    async def upsert(self, data: dict[str, dict]):
        """Store only the entries whose key is not present yet and return them."""
        new_entries = {}
        for key, value in data.items():
            if key not in self._data:
                new_entries[key] = value
        self._data.update(new_entries)
        return new_entries

    async def drop(self):
        """Discard all in-memory data."""
        self._data = {}

    async def filter(self, filter_func):
        """Filter key-value pairs based on a filter function

        Args:
            filter_func: The filter function, which takes a value as an argument and returns a boolean value

        Returns:
            Dict: Key-value pairs that meet the condition
        """
        async with self._lock:
            return {
                key: value for key, value in self._data.items() if filter_func(value)
            }

    async def delete(self, ids: list[str]):
        """Delete data with specified IDs

        Args:
            ids: List of IDs to delete
        """
        async with self._lock:
            for key in ids:
                self._data.pop(key, None)
            await self.index_done_callback()
        logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a NanoVectorDB client and its JSON storage file."""

    # Passed to the client as ``better_than_threshold`` on every query; can be
    # overridden via ``global_config["cosine_better_than_threshold"]``.
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed the ``content`` of each record and upsert it into the client.

        Args:
            data: Mapping of vector ID to a record dict. Each record must have
                a ``"content"`` entry; only keys listed in ``self.meta_fields``
                are stored alongside the vector.

        Returns:
            ``[]`` for empty input, the client's upsert result on success, or
            ``None`` when the number of embeddings does not match the number of
            records (the mismatch is logged and nothing is written).
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        current_time = time.time()
        list_data = [
            {
                "__id__": k,
                "__created_at__": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Create the bar before the closure that updates it, and always close
        # it so it is not left open when an embedding call fails.
        pbar = tqdm_async(
            total=len(batches), desc="Generating embeddings", unit="batch"
        )

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        try:
            embeddings_list = await asyncio.gather(
                *(wrapped_task(batch) for batch in batches)
            )
        finally:
            pbar.close()

        embeddings = np.concatenate(embeddings_list)
        if len(embeddings) != len(list_data):
            # sometimes the embedding is not returned correctly. just log it.
            logger.error(
                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
            )
            return None

        for record, vector in zip(list_data, embeddings):
            record["__vector__"] = vector
        return self._client.upsert(datas=list_data)

    async def query(self, query: str, top_k=5):
        """Embed ``query`` and return the client's matches.

        Each returned dict is the client's record extended with ``"id"``,
        ``"distance"`` (the client's ``__metrics__`` value) and ``"created_at"``.
        """
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        return [
            {
                **dp,
                "id": dp["__id__"],
                "distance": dp["__metrics__"],
                "created_at": dp.get("__created_at__"),
            }
            for dp in results
        ]

    @property
    def client_storage(self):
        """Return the client's name-mangled private ``__storage`` attribute."""
        return getattr(self._client, "_NanoVectorDB__storage")

    async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs

        Failures are logged rather than raised.

        Args:
            ids: List of vector IDs to be deleted
        """
        try:
            self._client.delete(ids)
            logger.info(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str):
        """Delete the vector whose ID is the ``ent-`` hash of ``entity_name``.

        Failures are logged rather than raised.
        """
        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            # Check if the entity exists
            if self._client.get([entity_id]):
                await self.delete([entity_id])
                logger.debug(f"Successfully deleted entity {entity_name}")
            else:
                logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str):
        """Delete every stored record whose ``src_id`` or ``tgt_id`` is ``entity_name``.

        Failures are logged rather than raised.
        """
        try:
            # Use .get so a record without src_id/tgt_id is skipped instead of
            # raising KeyError and aborting the whole deletion.
            relations = [
                dp
                for dp in self.client_storage["data"]
                if dp.get("src_id") == entity_name or dp.get("tgt_id") == entity_name
            ]
            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
            ids_to_delete = [relation["__id__"] for relation in relations]

            if ids_to_delete:
                await self.delete(ids_to_delete)
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def index_done_callback(self):
        """Ask the client to save its data."""
        self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """Graph storage kept in memory as a NetworkX graph and persisted as GraphML."""

    @staticmethod
    def load_nx_graph(file_name) -> Union[nx.Graph, None]:
        """Read a GraphML file, returning ``None`` when the file does not exist."""
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Write ``graph`` to ``file_name`` in GraphML format."""
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        # Normalise node names: upper-case, strip whitespace, unescape HTML entities.
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = sorted(graph.nodes(data=True), key=lambda x: x[0])
        fixed_graph.add_nodes_from(sorted_nodes)

        edges = list(graph.edges(data=True))
        if not graph.is_directed():
            # Orient every undirected edge so the smaller endpoint comes first.
            edges = [
                (target, source, edge_data)
                if source > target
                else (source, target, edge_data)
                for source, target, edge_data in edges
            ]

        # Sort on the textual "source -> target" key.
        edges.sort(key=lambda x: f"{x[0]} -> {x[1]}")
        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
            # Compare against None explicitly: a graph with zero nodes is
            # falsy, so ``preloaded_graph or nx.Graph()`` would silently
            # discard a loaded-but-empty graph.
            self._graph = preloaded_graph
        else:
            self._graph = nx.Graph()
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        """Persist the current graph to its GraphML file."""
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        """Return whether ``node_id`` is in the graph."""
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether the graph has an edge between the two nodes."""
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the attribute dict of ``node_id``, or ``None`` when absent."""
        return self._graph.nodes.get(node_id)

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of ``node_id``."""
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of the two endpoint nodes."""
        return self._graph.degree(src_id) + self._graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the attribute dict of the edge, or ``None`` when absent."""
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        """Return the node's edges as a list, or ``None`` when the node is absent."""
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Add ``node_id`` or update its attributes with ``node_data``."""
        self._graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Add the edge or update its attributes with ``edge_data``."""
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info(f"Node {node_id} deleted from the graph.")
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Run the named node-embedding algorithm.

        Raises:
            ValueError: If ``algorithm`` is not a registered algorithm name.
        """
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    # @TODO: NOT USED
    async def _node2vec_embed(self):
        """Embed the graph with graspologic node2vec using ``global_config["node2vec_params"]``.

        Returns the embeddings together with the ``"id"`` attribute of each
        embedded node.
        """
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

    def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node IDs to be deleted
        """
        for node in nodes:
            if self._graph.has_node(node):
                self._graph.remove_node(node)

    def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
        for source, target in edges:
            if self._graph.has_edge(source, target):
                self._graph.remove_edge(source, target)
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
    """Document status storage persisted as one JSON file in the working directory."""

    def __post_init__(self):
        # The backing file is named after this storage's namespace.
        self._file_name = os.path.join(
            self.global_config["working_dir"], f"kv_store_{self.namespace}.json"
        )
        self._data = load_json(self._file_name) or {}
        logger.info(f"Loaded document status storage with {len(self._data)} records")

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys that still need processing.

        A key is kept when it is not stored yet, or when its stored status is
        anything other than ``DocStatus.PROCESSED``.
        """
        to_process = set()
        for key in data:
            if key in self._data and self._data[key]["status"] == DocStatus.PROCESSED:
                continue
            to_process.add(key)
        return to_process

    async def get_status_counts(self) -> Dict[str, int]:
        """Return how many stored documents are in each ``DocStatus``."""
        tally = dict.fromkeys(DocStatus, 0)
        for record in self._data.values():
            tally[record["status"]] += 1
        return tally

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return every stored document whose status is ``DocStatus.FAILED``."""
        failed = {}
        for doc_id, record in self._data.items():
            if record["status"] == DocStatus.FAILED:
                failed[doc_id] = record
        return failed

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return every stored document whose status is ``DocStatus.PENDING``."""
        pending = {}
        for doc_id, record in self._data.items():
            if record["status"] == DocStatus.PENDING:
                pending[doc_id] = record
        return pending

    async def index_done_callback(self):
        """Write the in-memory records to the backing JSON file."""
        write_json(self._data, self._file_name)

    async def upsert(self, data: dict[str, dict]):
        """Insert or overwrite document status records and persist them.

        Args:
            data: Mapping of document ID to its status data.

        Returns:
            The ``data`` mapping that was passed in.
        """
        self._data.update(data)
        # Every upsert is flushed to disk immediately.
        await self.index_done_callback()
        return data

    async def get_by_id(self, id: str):
        """Return the stored record for ``id``, or ``None`` when absent."""
        return self._data.get(id)

    async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
        """Return the stored status for ``doc_id``, or ``None`` when absent."""
        return self._data.get(doc_id)

    async def delete(self, doc_ids: list[str]):
        """Remove the given document IDs (unknown IDs are ignored) and persist."""
        for doc_id in doc_ids:
            self._data.pop(doc_id, None)
        await self.index_done_callback()

View File

@@ -16,7 +16,9 @@ import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
from typing import List
import csv
import io
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""
@@ -235,17 +237,39 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialize rows of strings to CSV text.

    Every field is quoted, backslash is the escape character, and each row is
    terminated with ``"\\n"``.

    Args:
        data: Rows to write; each row is a list of field values.

    Returns:
        The CSV text, or an empty string when ``data`` is empty.
    """
    output = io.StringIO()
    # Build the writer once; the earlier QUOTE_ALL-only writer was superseded
    # by this fully specified configuration.
    writer = csv.writer(
        output,
        quoting=csv.QUOTE_ALL,  # Quote all fields
        escapechar="\\",  # Use backslash as escape character
        quotechar='"',  # Use double quotes
        lineterminator="\n",  # Explicit line terminator
    )
    writer.writerows(data)
    return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse CSV text into a list of rows.

    NUL characters are removed before parsing. The reader uses double quotes
    as the quote character and backslash as the escape character.

    Args:
        csv_string: The CSV text to parse.

    Returns:
        One list of field strings per CSV row; ``[]`` for empty input.

    Raises:
        ValueError: If the csv module reports a parsing error.
    """
    # Clean the string by removing NUL characters
    cleaned_string = csv_string.replace("\0", "")
    # The context manager closes the buffer on both success and failure.
    with io.StringIO(cleaned_string) as output:
        reader = csv.reader(
            output,
            quoting=csv.QUOTE_ALL,  # Match the writer configuration
            escapechar="\\",  # Use backslash as escape character
            quotechar='"',  # Use double quotes
        )
        try:
            return list(reader)
        except csv.Error as e:
            # Chain the original error so the root cause stays visible.
            raise ValueError(f"Failed to parse CSV string: {str(e)}") from e
def save_data_to_file(data, file_name):

View File

@@ -1,37 +1,24 @@
accelerate
aiofiles
aiohttp
asyncpg
configparser
# database packages
graspologic
gremlinpython
nano-vectordb
neo4j
networkx
graspologic
# TODO : Remove specific databases and move the installation to their corresponding files
# Use pipmaster for install if needed
# Basic modules
numpy
oracledb
pipmaster
psycopg-pool
psycopg[binary,pool]
pydantic
pymilvus
pymongo
pymysql
# File manipulation libraries
PyPDF2
python-docx
python-dotenv
python-pptx
pyvis
redis
setuptools
sqlalchemy
tenacity
@@ -39,3 +26,5 @@ tenacity
tiktoken
tqdm
xxhash
# Extra libraries are installed when needed using pipmaster