Merge branch 'main' into graph-viewer-webui
config.ini.example (new file, 15 lines)
@@ -0,0 +1,15 @@
+[neo4j]
+uri = neo4j+s://xxxxxxxx.databases.neo4j.io
+username = neo4j
+password = your-password
+
+[mongodb]
+uri = mongodb+srv://name:password@your-cluster-address
+database = lightrag
+graph = false
+
+[redis]
+uri=redis://localhost:6379/1
+
+[qdrant]
+uri = http://localhost:16333
@@ -113,14 +113,26 @@ if milvus_uri:
     os.environ["MILVUS_DB_NAME"] = milvus_db_name
     rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"

+# Qdrant config
+qdrant_uri = config.get("qdrant", "uri", fallback=None)
+qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
+if qdrant_uri:
+    os.environ["QDRANT_URL"] = qdrant_uri
+    if qdrant_api_key:
+        os.environ["QDRANT_API_KEY"] = qdrant_api_key
+    rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
+
 # MongoDB config
 mongo_uri = config.get("mongodb", "uri", fallback=None)
-mongo_database = config.get("mongodb", "LightRAG", fallback=None)
+mongo_database = config.get("mongodb", "database", fallback="LightRAG")
+mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
 if mongo_uri:
     os.environ["MONGO_URI"] = mongo_uri
     os.environ["MONGO_DATABASE"] = mongo_database
     rag_storage_config.KV_STORAGE = "MongoKVStorage"
     rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
+    if mongo_graph:
+        rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"


 def get_default_host(binding_type: str) -> str:
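Note: the block above relies on configparser fallbacks, so every key in config.ini.example is optional. A minimal standalone sketch of that behaviour (assuming a local config.ini laid out like config.ini.example; the file name is an assumption, not part of the diff):

    import configparser

    config = configparser.ConfigParser()
    config.read("config.ini")  # same layout as config.ini.example

    # Missing sections or keys fall back instead of raising.
    qdrant_uri = config.get("qdrant", "uri", fallback=None)
    qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
    mongo_database = config.get("mongodb", "database", fallback="LightRAG")
    mongo_graph = config.getboolean("mongodb", "graph", fallback=False)  # "false" -> False

    print(qdrant_uri, qdrant_api_key, mongo_database, mongo_graph)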
@@ -1,24 +1,26 @@
-from enum import Enum
 import os
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import (
+    Any,
+    Literal,
     Optional,
     TypedDict,
-    Union,
-    Literal,
     TypeVar,
-    Any,
+    Union,
 )

 import numpy as np

 from .utils import EmbeddingFunc


-TextChunkSchema = TypedDict(
-    "TextChunkSchema",
-    {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
-)
+class TextChunkSchema(TypedDict):
+    tokens: int
+    content: str
+    full_doc_id: str
+    chunk_order_index: int


 T = TypeVar("T")
@@ -57,11 +59,11 @@ class StorageNameSpace:
     global_config: dict[str, Any]

     async def index_done_callback(self):
-        """commit the storage operations after indexing"""
+        """Commit the storage operations after indexing"""
         pass

     async def query_done_callback(self):
-        """commit the storage operations after querying"""
+        """Commit the storage operations after querying"""
         pass

@@ -84,14 +86,14 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc

-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         raise NotImplementedError

-    async def filter_keys(self, data: list[str]) -> set[str]:
-        """return un-exist keys"""
+    async def filter_keys(self, data: set[str]) -> set[str]:
+        """Return un-exist keys"""
         raise NotImplementedError

     async def upsert(self, data: dict[str, Any]) -> None:
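Note: the signature changes above tighten the key-value storage contract: get_by_id may now return None for a missing key, and filter_keys takes and returns a set of keys. A small caller sketch written against the new contract (the kv object here is hypothetical, standing in for any BaseKVStorage implementation):

    from typing import Any, Union

    async def load_existing(kv, ids: list[str]) -> dict[str, Any]:
        # filter_keys answers with the subset of keys NOT yet in storage.
        missing: set[str] = await kv.filter_keys(set(ids))
        found: dict[str, Any] = {}
        for doc_id in ids:
            if doc_id in missing:
                continue
            record: Union[dict[str, Any], None] = await kv.get_by_id(doc_id)
            if record is not None:  # None instead of {} for absent keys
                found[doc_id] = record
        return found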
@@ -1,16 +1,16 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Union

-from lightrag.utils import (
-    logger,
-    load_json,
-    write_json,
-)
 from lightrag.base import (
     BaseKVStorage,
 )
+from lightrag.utils import (
+    load_json,
+    logger,
+    write_json,
+)


 @dataclass
@@ -25,8 +25,8 @@ class JsonKVStorage(BaseKVStorage):
     async def index_done_callback(self):
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id: str) -> dict[str, Any]:
-        return self._data.get(id, {})
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return [
@@ -38,8 +38,8 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def filter_keys(self, data: list[str]) -> set[str]:
-        return set([s for s in data if s not in self._data])
+    async def filter_keys(self, data: set[str]) -> set[str]:
+        return set(data) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -48,21 +48,20 @@ Usage:

 """

-import os
 from dataclasses import dataclass
+import os
 from typing import Any, Union

-from lightrag.utils import (
-    logger,
-    load_json,
-    write_json,
-)
-
 from lightrag.base import (
-    DocStatus,
     DocProcessingStatus,
+    DocStatus,
     DocStatusStorage,
 )
+from lightrag.utils import (
+    load_json,
+    logger,
+    write_json,
+)


 @dataclass
@@ -75,15 +74,17 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(
-            [
-                k
-                for k in data
-                if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
-            ]
-        )
+        return set(data) - set(self._data.keys())
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        result: list[dict[str, Any]] = []
+        for id in ids:
+            data = self._data.get(id, None)
+            if data:
+                result.append(data)
+        return result

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
@@ -94,11 +95,19 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.FAILED
+        }

     async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PENDING
+        }

     async def index_done_callback(self):
         """Save data to file after indexing"""
@@ -113,12 +122,8 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data.update(data)
         await self.index_done_callback()

-    async def get_by_id(self, id: str) -> dict[str, Any]:
-        return self._data.get(id, {})
-
-    async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
-        """Get document status by ID"""
-        return self._data.get(doc_id)
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
         """Delete document status by IDs"""
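Note: get_failed_docs and get_pending_docs above now rehydrate the stored dicts into DocProcessingStatus instances via ** unpacking instead of returning raw dicts. The sketch below shows that pattern with a simplified, hypothetical status dataclass (not the real DocProcessingStatus definition):

    from dataclasses import dataclass
    from enum import Enum

    class Status(str, Enum):  # local stand-in for DocStatus
        PENDING = "pending"
        FAILED = "failed"
        PROCESSED = "processed"

    @dataclass
    class StatusRecord:  # local stand-in for DocProcessingStatus
        content_summary: str
        status: Status
        updated_at: str

    stored = {
        "doc-1": {"content_summary": "intro", "status": Status.FAILED, "updated_at": "2025-01-01"},
        "doc-2": {"content_summary": "body", "status": Status.PROCESSED, "updated_at": "2025-01-02"},
    }

    # Same shape as get_failed_docs(): unpack each stored dict into a dataclass.
    failed = {k: StatusRecord(**v) for k, v in stored.items() if v["status"] == Status.FAILED}
    print(failed)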
@@ -1,8 +1,9 @@
 import os
-from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
-import pipmaster as pm
 import numpy as np
+import pipmaster as pm
+from tqdm.asyncio import tqdm as tqdm_async

 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -10,13 +11,14 @@ if not pm.is_installed("pymongo"):
 if not pm.is_installed("motor"):
     pm.install("motor")

-from pymongo import MongoClient
-from motor.motor_asyncio import AsyncIOMotorClient
-from typing import Any, Union, List, Tuple
+from typing import Any, List, Tuple, Union

-from ..utils import logger
-from ..base import BaseKVStorage, BaseGraphStorage
+from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo import MongoClient
+
+from ..base import BaseGraphStorage, BaseKVStorage
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger


 @dataclass
@@ -29,13 +31,13 @@ class MongoKVStorage(BaseKVStorage):
         self._data = database.get_collection(self.namespace)
         logger.info(f"Use MongoDB as KV {self.namespace}")

-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return list(self._data.find({"_id": {"$in": ids}}))

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         existing_ids = [
             str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
         ]
@@ -170,7 +172,6 @@ class MongoGraphStorage(BaseGraphStorage):
         But typically for a direct edge, we might just do a find_one.
         Below is a demonstration approach.
         """
-
         # We can do a single-hop graphLookup (maxDepth=0 or 1).
         # Then check if the target_node appears among the edges array.
         pipeline = [
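Note: for readers unfamiliar with the aggregation referenced in the comments above, here is a standalone single-hop $graphLookup sketch; the collection name and the edges/_id fields are illustrative, not the actual MongoGraphStorage schema:

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")  # assumed local instance
    nodes = client["lightrag"]["nodes"]                # hypothetical collection

    pipeline = [
        {"$match": {"_id": "source-node-id"}},
        {
            "$graphLookup": {
                "from": "nodes",
                "startWith": "$edges",        # hypothetical array of neighbour ids
                "connectFromField": "edges",
                "connectToField": "_id",
                "as": "reachable",
                "maxDepth": 0,                # single hop
            }
        },
        {"$project": {"has_edge": {"$in": ["target-node-id", "$reachable._id"]}}},
    ]
    print(list(nodes.aggregate(pipeline)))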
@@ -1,27 +1,28 @@
-import os
+import array
 import asyncio
+import os

 # import html
 # import os
 from dataclasses import dataclass
 from typing import Any, Union

 import numpy as np
-import array
 import pipmaster as pm

 if not pm.is_installed("oracledb"):
     pm.install("oracledb")

-from ..utils import logger
+import oracledb
+
 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
 )
 from ..namespace import NameSpace, is_namespace
-
-import oracledb
+from ..utils import logger


 class OracleDB:
@@ -107,7 +108,7 @@ class OracleDB:
                     "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
                 )
             else:
-                await self.query("SELECT 1 FROM {k}".format(k=k))
+                await self.query(f"SELECT 1 FROM {k}")
         except Exception as e:
             logger.error(f"Failed to check table {k} in Oracle database")
             logger.error(f"Oracle database error: {e}")
@@ -181,8 +182,8 @@ class OracleKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> dict[str, Any]:
-        """get doc_full data based on id."""
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        """Get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
@@ -191,7 +192,10 @@ class OracleKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
-            return res
+            if res:
+                return res
+            else:
+                return None
         else:
             return await self.db.query(SQL, params)
@@ -209,7 +213,7 @@ class OracleKVStorage(BaseKVStorage):
         return None

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """get doc_chunks data based on id"""
+        """Get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
@@ -4,34 +4,35 @@ import json
 import os
 import time
 from dataclasses import dataclass
-from typing import Union, List, Dict, Set, Any, Tuple
-import numpy as np
+from typing import Any, Dict, List, Set, Tuple, Union

+import numpy as np
 import pipmaster as pm

 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")

-import asyncpg
 import sys
-from tqdm.asyncio import tqdm as tqdm_async

+import asyncpg
 from tenacity import (
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
+from tqdm.asyncio import tqdm as tqdm_async

-from ..utils import logger
 from ..base import (
+    BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
-    DocStatusStorage,
-    DocStatus,
     DocProcessingStatus,
-    BaseGraphStorage,
+    DocStatus,
+    DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger

 if sys.platform.startswith("win"):
     import asyncio.windows_events
@@ -82,7 +83,7 @@ class PostgreSQLDB:
     async def check_tables(self):
         for k, v in TABLES.items():
             try:
-                await self.query("SELECT 1 FROM {k} LIMIT 1".format(k=k))
+                await self.query(f"SELECT 1 FROM {k} LIMIT 1")
             except Exception as e:
                 logger.error(f"Failed to check table {k} in PostgreSQL database")
                 logger.error(f"PostgreSQL database error: {e}")
@@ -183,7 +184,7 @@ class PGKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -192,9 +193,10 @@ class PGKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
-            return res
+            return res if res else None
         else:
-            return await self.db.query(sql, params)
+            response = await self.db.query(sql, params)
+            return response if response else None

     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         """Specifically for llm_response_cache."""
@@ -421,7 +423,7 @@ class PGDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
        pass

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         keys = ",".join([f"'{_id}'" for _id in data])
         sql = (
@@ -435,12 +437,12 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed

-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.db.workspace, "id": id}
         result = await self.db.query(sql, params, True)
         if result is None or result == []:
-            return {}
+            return None
         else:
             return DocProcessingStatus(
                 content=result[0]["content"],
lightrag/kg/qdrant_impl.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+import asyncio
+import os
+from tqdm.asyncio import tqdm as tqdm_async
+from dataclasses import dataclass
+import numpy as np
+import hashlib
+import uuid
+
+from ..utils import logger
+from ..base import BaseVectorStorage
+
+import pipmaster as pm
+
+if not pm.is_installed("qdrant_client"):
+    pm.install("qdrant_client")
+
+from qdrant_client import QdrantClient, models
+
+
+def compute_mdhash_id_for_qdrant(
+    content: str, prefix: str = "", style: str = "simple"
+) -> str:
+    """
+    Generate a UUID based on the content and support multiple formats.
+
+    :param content: The content used to generate the UUID.
+    :param style: The format of the UUID, optional values are "simple", "hyphenated", "urn".
+    :return: A UUID that meets the requirements of Qdrant.
+    """
+    if not content:
+        raise ValueError("Content must not be empty.")
+
+    # Use the hash value of the content to create a UUID.
+    hashed_content = hashlib.sha256((prefix + content).encode("utf-8")).digest()
+    generated_uuid = uuid.UUID(bytes=hashed_content[:16], version=4)
+
+    # Return the UUID according to the specified format.
+    if style == "simple":
+        return generated_uuid.hex
+    elif style == "hyphenated":
+        return str(generated_uuid)
+    elif style == "urn":
+        return f"urn:uuid:{generated_uuid}"
+    else:
+        raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
+
+
+@dataclass
+class QdrantVectorDBStorage(BaseVectorStorage):
+    @staticmethod
+    def create_collection_if_not_exist(
+        client: QdrantClient, collection_name: str, **kwargs
+    ):
+        if client.collection_exists(collection_name):
+            return
+        client.create_collection(collection_name, **kwargs)
+
+    def __post_init__(self):
+        self._client = QdrantClient(
+            url=os.environ.get("QDRANT_URL"),
+            api_key=os.environ.get("QDRANT_API_KEY", None),
+        )
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        QdrantVectorDBStorage.create_collection_if_not_exist(
+            self._client,
+            self.namespace,
+            vectors_config=models.VectorParams(
+                size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
+            ),
+        )
+
+    async def upsert(self, data: dict[str, dict]):
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not len(data):
+            logger.warning("You insert an empty data to vector DB")
+            return []
+        list_data = [
+            {
+                "id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        async def wrapped_task(batch):
+            result = await self.embedding_func(batch)
+            pbar.update(1)
+            return result
+
+        embedding_tasks = [wrapped_task(batch) for batch in batches]
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
+        embeddings_list = await asyncio.gather(*embedding_tasks)
+
+        embeddings = np.concatenate(embeddings_list)
+
+        list_points = []
+        for i, d in enumerate(list_data):
+            list_points.append(
+                models.PointStruct(
+                    id=compute_mdhash_id_for_qdrant(d["id"]),
+                    vector=embeddings[i],
+                    payload=d,
+                )
+            )
+
+        results = self._client.upsert(
+            collection_name=self.namespace, points=list_points, wait=True
+        )
+        return results
+
+    async def query(self, query, top_k=5):
+        embedding = await self.embedding_func([query])
+        results = self._client.search(
+            collection_name=self.namespace,
+            query_vector=embedding[0],
+            limit=top_k,
+            with_payload=True,
+        )
+        logger.debug(f"query result: {results}")
+        return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
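Note: the compute_mdhash_id_for_qdrant helper above exists because Qdrant point IDs must be unsigned integers or UUIDs rather than arbitrary strings. A standalone check of the hashing idea (plain hashlib/uuid, no Qdrant needed; the function name below is a local stand-in, not the helper from the diff):

    import hashlib
    import uuid

    def content_to_uuid(content: str, prefix: str = "") -> str:
        # Derive a stable UUID from the SHA-256 of the prefixed content,
        # mirroring the "hyphenated" style of the helper above.
        digest = hashlib.sha256((prefix + content).encode("utf-8")).digest()
        return str(uuid.UUID(bytes=digest[:16], version=4))

    a = content_to_uuid("hello world", prefix="chunk-")
    b = content_to_uuid("hello world", prefix="chunk-")
    assert a == b  # same content -> same point ID
    print(a)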
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, Union
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -21,7 +21,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")

-    async def get_by_id(self, id):
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None

@@ -32,7 +32,7 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
         for key in data:
             pipe.exists(f"{self.namespace}:{key}")
@@ -14,12 +14,12 @@ if not pm.is_installed("sqlalchemy"):
 from sqlalchemy import create_engine, text
 from tqdm import tqdm

-from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
-from ..utils import logger
+from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger


-class TiDB(object):
+class TiDB:
     def __init__(self, config, **kwargs):
         self.host = config.get("host", None)
         self.port = config.get("port", None)
@@ -108,12 +108,12 @@ class TiDBKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
-        # print("get_by_id:"+SQL)
-        return await self.db.query(SQL, params)
+        response = await self.db.query(SQL, params)
+        return response if response else None

     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -178,7 +178,7 @@ class TiDBKVStorage(BaseKVStorage):
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
-                    "content_vector": f"{item['__vector__'].tolist()}",
+                    "content_vector": f'{item["__vector__"].tolist()}',
                     "workspace": self.db.workspace,
                 }
             )
@@ -222,8 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         )

     async def query(self, query: str, top_k: int) -> list[dict]:
-        """search from tidb vector"""
-
+        """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]

@@ -286,7 +285,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "id": item["id"],
                 "name": item["entity_name"],
                 "content": item["content"],
-                "content_vector": f"{item['content_vector'].tolist()}",
+                "content_vector": f'{item["content_vector"].tolist()}',
                 "workspace": self.db.workspace,
             }
             # update entity_id if node inserted by graph_storage_instance before
@@ -308,7 +307,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "source_name": item["src_id"],
                 "target_name": item["tgt_id"],
                 "content": item["content"],
-                "content_vector": f"{item['content_vector'].tolist()}",
+                "content_vector": f'{item["content_vector"].tolist()}',
                 "workspace": self.db.workspace,
             }
             # update relation_id if node inserted by graph_storage_instance before
@@ -1,28 +1,10 @@
 import asyncio
 import os
-from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Coroutine, Optional, Type, Union, cast
-from .operate import (
-    chunking_by_token_size,
-    extract_entities,
-    extract_keywords_only,
-    kg_query,
-    kg_query_with_keywords,
-    mix_kg_vector_query,
-    naive_query,
-)
-
-from .utils import (
-    EmbeddingFunc,
-    compute_mdhash_id,
-    limit_async_func_call,
-    convert_response_to_json,
-    logger,
-    set_logger,
-)
+from typing import Any, Callable, Optional, Type, Union, cast
+
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -33,10 +15,25 @@ from .base import (
     QueryParam,
     StorageNameSpace,
 )
 from .namespace import NameSpace, make_namespace
+from .operate import (
+    chunking_by_token_size,
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
+)
 from .prompt import GRAPH_FIELD_SEP
+from .utils import (
+    EmbeddingFunc,
+    compute_mdhash_id,
+    convert_response_to_json,
+    limit_async_func_call,
+    logger,
+    set_logger,
+)

 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
@@ -62,12 +59,12 @@ STORAGES = {
     "GremlinStorage": ".kg.gremlin_impl",
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
+    "QdrantVectorDBStorage": ".kg.qdrant_impl",
 }


 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""

     # Get the caller's module and package
     import inspect

@@ -113,7 +110,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 @dataclass
 class LightRAG:
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
     )
     # Default not to use embedding cache
     embedding_cache_config: dict = field(
@@ -412,7 +409,7 @@ class LightRAG:
             doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
             new_docs = {doc_key: {"content": full_text.strip()}}

-            _add_doc_keys = await self.full_docs.filter_keys([doc_key])
+            _add_doc_keys = await self.full_docs.filter_keys(set(doc_key))
             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
             if not len(new_docs):
                 logger.warning("This document is already in the storage.")
@@ -421,7 +418,7 @@ class LightRAG:
             update_storage = True
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")

-            inserting_chunks = {}
+            inserting_chunks: dict[str, Any] = {}
             for chunk_text in text_chunks:
                 chunk_text_stripped = chunk_text.strip()
                 chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")
@@ -431,37 +428,22 @@ class LightRAG:
                     "full_doc_id": doc_key,
                 }

-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
+            doc_ids = set(inserting_chunks.keys())
+            add_chunk_keys = await self.text_chunks.filter_keys(doc_ids)
             inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+                k: v for k, v in inserting_chunks.items() if k in add_chunk_keys
             }
             if not len(inserting_chunks):
                 logger.warning("All chunks are already in the storage.")
                 return

-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
-            await self.chunks_vdb.upsert(inserting_chunks)
-
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
-                inserting_chunks,
-                knowledge_graph_inst=self.chunk_entity_relation_graph,
-                entity_vdb=self.entities_vdb,
-                relationships_vdb=self.relationships_vdb,
-                global_config=asdict(self),
-            )
-
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            else:
-                self.chunk_entity_relation_graph = maybe_new_kg
-
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
+            tasks = [
+                self.chunks_vdb.upsert(inserting_chunks),
+                self._process_entity_relation_graph(inserting_chunks),
+                self.full_docs.upsert(new_docs),
+                self.text_chunks.upsert(inserting_chunks),
+            ]
+            await asyncio.gather(*tasks)

         finally:
             if update_storage:
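Note: the rewrite above replaces a strictly sequential pipeline with four independent coroutines awaited together. A self-contained sketch of that pattern, with toy coroutines standing in for the storage upserts:

    import asyncio

    async def fake_upsert(name: str, delay: float) -> str:
        await asyncio.sleep(delay)      # stand-in for an independent storage upsert
        return f"{name} done"

    async def main() -> None:
        tasks = [
            fake_upsert("chunks_vdb", 0.2),
            fake_upsert("entity_relation_graph", 0.3),
            fake_upsert("full_docs", 0.1),
            fake_upsert("text_chunks", 0.1),
        ]
        # gather runs the coroutines concurrently and returns results in order;
        # the first exception propagates to the caller (return_exceptions=False).
        print(await asyncio.gather(*tasks))

    asyncio.run(main())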
@@ -496,15 +478,12 @@ class LightRAG:
         }

         # 3. Filter out already processed documents
-        add_doc_keys: set[str] = set()
         # Get docs ids
-        in_process_keys = list(new_docs.keys())
-        # Get in progress docs ids
-        excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
-        # Exclude already in process
-        add_doc_keys = new_docs.keys() - excluded_ids
-        # Filter
-        new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
+        all_new_doc_ids = set(new_docs.keys())
+        # Exclude IDs of documents that are already in progress
+        unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
+        # Filter new_docs to only include documents with unique IDs
+        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}

         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
@@ -535,47 +514,32 @@ class LightRAG:
         # Fetch failed documents
         failed_docs = await self.doc_status.get_failed_docs()
         to_process_docs.update(failed_docs)
-        pending_docs = await self.doc_status.get_pending_docs()
-        to_process_docs.update(pending_docs)
+        pendings_docs = await self.doc_status.get_pending_docs()
+        to_process_docs.update(pendings_docs)

         if not to_process_docs:
             logger.info("All documents have been processed or are duplicates")
             return

-        to_process_docs_ids = list(to_process_docs.keys())
-
-        # Get allready processed documents (text chunks and full docs)
-        text_chunks_processed_doc_ids = await self.text_chunks.filter_keys(
-            to_process_docs_ids
-        )
-        full_docs_processed_doc_ids = await self.full_docs.filter_keys(
-            to_process_docs_ids
-        )
-
         # 2. split docs into chunks, insert chunks, update doc status
         batch_size = self.addon_params.get("insert_batch_size", 10)
-        batch_docs_list = [
+        docs_batches = [
             list(to_process_docs.items())[i : i + batch_size]
             for i in range(0, len(to_process_docs), batch_size)
         ]

+        logger.info(f"Number of batches to process: {len(docs_batches)}.")
+
         # 3. iterate over batches
-        tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
-        for batch_idx, ids_doc_processing_status in tqdm_async(
-            enumerate(batch_docs_list),
-            desc="Process Batches",
-        ):
+        for batch_idx, docs_batch in enumerate(docs_batches):
             # 4. iterate over batch
-            for id_doc_processing_status in tqdm_async(
-                ids_doc_processing_status,
-                desc=f"Process Batch {batch_idx}",
-            ):
-                id_doc, status_doc = id_doc_processing_status
+            for doc_id_processing_status in docs_batch:
+                doc_id, status_doc = doc_id_processing_status
                 # Update status in processing
+                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                 await self.doc_status.upsert(
                     {
-                        id_doc: {
+                        doc_status_id: {
                             "status": DocStatus.PROCESSING,
                             "updated_at": datetime.now().isoformat(),
                             "content_summary": status_doc.content_summary,
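Note: the batching above is a plain slice of the items list into fixed-size groups. A quick standalone check of the pattern (sample data and batch size are illustrative; the diff defaults insert_batch_size to 10):

    to_process_docs = {f"doc-{i}": {"content": f"text {i}"} for i in range(7)}
    batch_size = 3

    docs_batches = [
        list(to_process_docs.items())[i : i + batch_size]
        for i in range(0, len(to_process_docs), batch_size)
    ]

    print([len(b) for b in docs_batches])  # 7 items, batch_size=3 -> [3, 3, 1]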
@@ -588,7 +552,7 @@ class LightRAG:
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
                         **dp,
-                        "full_doc_id": id_doc_processing_status,
+                        "full_doc_id": doc_id,
                     }
                     for dp in self.chunking_func(
                         status_doc.content,
@@ -600,28 +564,18 @@ class LightRAG:
                     )
                 }

-                # Ensure chunk insertion and graph processing happen sequentially, not in parallel
-                await self.chunks_vdb.upsert(chunks)
-                await self._process_entity_relation_graph(chunks)
-
-                tasks[id_doc] = []
-                # Check if document already processed the doc
-                if id_doc not in full_docs_processed_doc_ids:
-                    tasks[id_doc].append(
-                        self.full_docs.upsert({id_doc: {"content": status_doc.content}})
-                    )
-
-                # Check if chunks already processed the doc
-                if id_doc not in text_chunks_processed_doc_ids:
-                    tasks[id_doc].append(self.text_chunks.upsert(chunks))
-
                 # Process document (text chunks and full docs) in parallel
-            for id_doc_processing_status, task in tasks.items():
+                tasks = [
+                    self.chunks_vdb.upsert(chunks),
+                    self._process_entity_relation_graph(chunks),
+                    self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
+                    self.text_chunks.upsert(chunks),
+                ]
                 try:
-                    await asyncio.gather(*task)
+                    await asyncio.gather(*tasks)
                     await self.doc_status.upsert(
                         {
-                            id_doc_processing_status: {
+                            doc_status_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
                                 "updated_at": datetime.now().isoformat(),
@@ -631,12 +585,10 @@ class LightRAG:
                     await self._insert_done()

                 except Exception as e:
-                    logger.error(
-                        f"Failed to process document {id_doc_processing_status}: {str(e)}"
-                    )
+                    logger.error(f"Failed to process document {doc_id}: {str(e)}")
                     await self.doc_status.upsert(
                         {
-                            id_doc_processing_status: {
+                            doc_status_id: {
                                 "status": DocStatus.FAILED,
                                 "error": str(e),
                                 "updated_at": datetime.now().isoformat(),
@@ -644,6 +596,7 @@ class LightRAG:
                         }
                     )
                     continue
+            logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -656,8 +609,9 @@ class LightRAG:
                 global_config=asdict(self),
             )
             if new_kg is None:
-                logger.info("No entities or relationships extracted!")
+                logger.info("No new entities or relationships extracted.")
             else:
+                logger.info("New entities or relationships extracted.")
                 self.chunk_entity_relation_graph = new_kg

         except Exception as e:
@@ -895,7 +849,6 @@ class LightRAG:
        1. Extract keywords from the 'query' using new function in operate.py.
        2. Then run the standard aquery() flow with the final prompt (formatted_question).
        """
-
        loop = always_get_an_event_loop()
        return loop.run_until_complete(
            self.aquery_with_separate_keyword_extraction(query, prompt, param)
@@ -908,7 +861,6 @@ class LightRAG:
        1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
        """
-
        # ---------------------
        # STEP 1: Keyword Extraction
        # ---------------------