From 3eba41aab66dd31b7ea3ee3f8a732e02f4475097 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:24:42 +0100 Subject: [PATCH 01/17] updated clean of what implemented on BaseVectorStorage --- lightrag/kg/chroma_impl.py | 17 +++++++++++++---- lightrag/kg/faiss_impl.py | 9 +++++---- lightrag/kg/milvus_impl.py | 16 ++++++++++++++-- lightrag/kg/mongo_impl.py | 14 ++++++++++++-- lightrag/kg/nano_vector_db_impl.py | 9 +++++---- lightrag/kg/oracle_impl.py | 15 ++++++++++----- lightrag/kg/postgres_impl.py | 16 +++++++++++----- lightrag/kg/qdrant_impl.py | 16 ++++++++++++++-- lightrag/kg/tidb_impl.py | 12 ++++++++++-- 9 files changed, 94 insertions(+), 30 deletions(-) diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index cb3b59f1..e32346f9 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,6 +1,6 @@ import asyncio from dataclasses import dataclass -from typing import Union +from typing import Any import numpy as np from chromadb import HttpClient, PersistentClient from chromadb.config import Settings @@ -102,7 +102,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"ChromaDB initialization failed: {str(e)}") raise - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: logger.warning("Empty data provided to vector DB") return [] @@ -151,7 +151,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB upsert: {str(e)}") raise - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: try: embedding = await self.embedding_func([query]) @@ -183,6 +183,15 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB query: {str(e)}") raise - async def index_done_callback(self): + + async def index_done_callback(self) -> None: # ChromaDB handles persistence automatically pass + + 
async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 9a5f7e4e..3027f3f0 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -1,6 +1,7 @@ import os import time import asyncio +from typing import Any import faiss import json import numpy as np @@ -57,7 +58,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Attempt to load an existing index + metadata from disk self._load_faiss_index() - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -147,7 +148,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") return [m["__id__"] for m in list_data] - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance. """ @@ -210,7 +211,7 @@ class FaissVectorDBStorage(BaseVectorStorage): f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) - async def delete_entity(self, entity_name: str): + async def delete_entity(self, entity_name: str) -> None: """ Delete a single entity by computing its hashed ID the same way your code does it with `compute_mdhash_id`. @@ -234,7 +235,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(relations) logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - async def index_done_callback(self): + async def index_done_callback(self) -> None: """ Called after indexing is done (save Faiss index + metadata). 
""" diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f4d9d47f..d67f03b1 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -71,7 +72,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): dimension=self.embedding_func.embedding_dim, ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -106,7 +107,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): results = self._client.upsert(collection_name=self.namespace, data=list_data) return results - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, @@ -123,3 +124,14 @@ class MilvusVectorDBStorage(BaseVectorStorage): {**dp["entity"], "id": dp["id"], "distance": dp["distance"]} for dp in results[0] ] + + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index c216e7be..39bb9f18 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -844,7 +844,7 @@ class MongoVectorDBStorage(BaseVectorStorage): except PyMongoError as _: logger.debug("vector index already exist") - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: 
dict[str, dict[str, Any]]) -> None: logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") if not data: logger.warning("You are inserting an empty data set to vector DB") @@ -887,7 +887,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" # Generate the embedding embedding = await self.embedding_func([query]) @@ -921,6 +921,16 @@ class MongoVectorDBStorage(BaseVectorStorage): for doc in results ] + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): """Check if the collection exists. 
if not, create it.""" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 5d786646..8b931424 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -50,6 +50,7 @@ Usage: import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -95,7 +96,7 @@ class NanoVectorDBStorage(BaseVectorStorage): self.embedding_func.embedding_dim, storage_file=self._client_file_name ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -139,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage): f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" ) - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) embedding = embedding[0] results = self._client.query( @@ -176,7 +177,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - async def delete_entity(self, entity_name: str): + async def delete_entity(self, entity_name: str) -> None: try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( @@ -211,7 +212,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") - async def index_done_callback(self): + async def index_done_callback(self) -> None: # Protect file write operation async with self._save_lock: self._client.save() diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 65f1060c..197d101e 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -307,7 +307,7 
@@ class OracleKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: if is_namespace( self.namespace, (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), @@ -330,16 +330,14 @@ class OracleVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - async def upsert(self, data: dict[str, dict]): - """向向量数据库中插入数据""" + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pass async def index_done_callback(self): pass #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] # 转换精度 @@ -359,6 +357,13 @@ class OracleVectorDBStorage(BaseVectorStorage): # print("vector search result:",results) return results + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError @dataclass class OracleGraphStorage(BaseGraphStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a44aefe7..5dbc6a8e 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -287,7 +287,7 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: if is_namespace( self.namespace, (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), @@ -352,7 +352,7 @@ class PGVectorStorage(BaseVectorStorage): } return upsert_sql, data - async def upsert(self, data: Dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: 
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -398,12 +398,11 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: logger.info("vector data had been saved into postgresql db!") #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) @@ -417,6 +416,13 @@ class PGVectorStorage(BaseVectorStorage): results = await self.db.query(sql, params=params, multirows=True) return results + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError @dataclass class PGDocStatusStorage(DocStatusStorage): diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 3af76328..18a50082 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -85,7 +86,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ), ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not len(data): logger.warning("You insert an empty data to vector DB") return [] @@ -130,7 +131,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) return results - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> 
list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, @@ -143,3 +144,14 @@ class QdrantVectorDBStorage(BaseVectorStorage): logger.debug(f"query result: {results}") return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] + + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 00b8003d..a5a5c80d 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -227,7 +227,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - async def query(self, query: str, top_k: int) -> list[dict]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Search from tidb vector""" embeddings = await self.embedding_func([query]) embedding = embeddings[0] @@ -249,7 +249,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): return results ###### INSERT entities And relationships ###### - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: # ignore, upsert in TiDBKVStorage already if not len(data): logger.warning("You insert an empty data to vector DB") @@ -333,6 +333,14 @@ class TiDBVectorDBStorage(BaseVectorStorage): return await self.db.query(SQL, params, multirows=True) + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError + 
@dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use From 71a18d1de97f7ab1d62174ff2493590e39f8d74b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:31:12 +0100 Subject: [PATCH 02/17] updated clean of what implemented on BaseKVStorage --- lightrag/base.py | 4 ++-- lightrag/kg/json_kv_impl.py | 6 +++--- lightrag/kg/mongo_impl.py | 7 +++++-- lightrag/kg/oracle_impl.py | 8 +++++--- lightrag/kg/postgres_impl.py | 10 ++++++---- lightrag/kg/redis_impl.py | 11 +++++++---- lightrag/kg/tidb_impl.py | 8 +++++--- 7 files changed, 33 insertions(+), 21 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3d4fc022..8efbe8a2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -121,11 +121,11 @@ class BaseKVStorage(StorageNameSpace): async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: raise NotImplementedError - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return un-exist keys""" raise NotImplementedError - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: raise NotImplementedError async def drop(self) -> None: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 3ab5b966..5683801f 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any from lightrag.base import ( BaseKVStorage, @@ -25,7 +25,7 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -38,7 +38,7 @@ class 
JsonKVStorage(BaseKVStorage):
                 for id in ids
             ]
 
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
-        return set(data) - set(self._data.keys())
+        return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 39bb9f18..44820ecf 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -60,14 +60,14 @@ class MongoKVStorage(BaseKVStorage):
         # Ensure collection exists
         create_collection_if_not_exists(uri, database.name, self._collection_name)
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return await self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
         return await cursor.to_list()
 
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
-        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
         existing_ids = {str(x["_id"]) async for x in cursor}
-        return data - existing_ids
+        return keys - existing_ids
@@ -107,6 +107,9 @@ class MongoKVStorage(BaseKVStorage):
         else:
             return None
 
+    async def index_done_callback(self) -> None:
+        pass
+
     async def drop(self) -> None:
         """Drop the collection"""
         await self._data.drop()
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 197d101e..95d888b3 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
             res
= [{k: v} for k, v in dict_res.items()] return res - async def filter_keys(self, keys: list[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -248,7 +248,7 @@ class OracleKVStorage(BaseKVStorage): return set(keys) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { @@ -314,6 +314,8 @@ class OracleKVStorage(BaseKVStorage): ): logger.info("full doc and chunk data had been saved into oracle db!") + async def drop(self) -> None: + raise NotImplementedError @dataclass class OracleVectorDBStorage(BaseVectorStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 5dbc6a8e..98f9c495 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,7 +4,7 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import numpy as np import pipmaster as pm @@ -185,7 +185,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -240,7 +240,7 @@ class PGKVStorage(BaseKVStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) - async def filter_keys(self, keys: List[str]) -> Set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" sql = 
SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +261,7 @@ class PGKVStorage(BaseKVStorage):
             print(params)
 
     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -294,6 +294,8 @@ class PGKVStorage(BaseKVStorage):
         ):
             logger.info("full doc and chunk data had been saved into postgresql db!")
 
+    async def drop(self) -> None:
+        raise NotImplementedError
 
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index ed8f46f9..f735c72a 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Union
+from typing import Any
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -28,7 +28,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
@@ -39,7 +39,7 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]
 
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
-        for key in data:
+        for key in keys:
             pipe.exists(f"{self.namespace}:{key}")
@@ -48,7 +48,7 @@ class RedisKVStorage(BaseKVStorage):
-        existing_ids = {data[i] for i, exists in enumerate(results) if exists}
+        existing_ids = {k for k, exists in zip(keys, results) if exists}
-        return set(data) - existing_ids
+        return set(keys) - existing_ids
 
-    async def upsert(self, data:
dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pipe = self._redis.pipeline() for k, v in tqdm_async(data.items(), desc="Upserting"): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) @@ -61,3 +61,6 @@ class RedisKVStorage(BaseKVStorage): keys = await self._redis.keys(f"{self.namespace}:*") if keys: await self._redis.delete(*keys) + + async def index_done_callback(self) -> None: + pass \ No newline at end of file diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index a5a5c80d..6f388e7f 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -110,7 +110,7 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Fetch doc_full data by id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} @@ -125,7 +125,7 @@ class TiDBKVStorage(BaseKVStorage): ) return await self.db.query(SQL, multirows=True) - async def filter_keys(self, keys: list[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -147,7 +147,7 @@ class TiDBKVStorage(BaseKVStorage): return data ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): @@ -207,6 +207,8 @@ class TiDBKVStorage(BaseKVStorage): ): logger.info("full doc and chunk data had been saved into TiDB db!") + async def drop(self) -> None: + raise NotImplementedError @dataclass class TiDBVectorDBStorage(BaseVectorStorage): From 882190a515abde0b9c71f95d47aa36c31c9f0ed9 Mon 
Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:53:59 +0100 Subject: [PATCH 03/17] updated clean of what implemented on DocStatusStorage --- lightrag/base.py | 54 +++++++++++++++--------------------- lightrag/kg/age_impl.py | 38 +++++++++++++------------ lightrag/kg/gremlin_impl.py | 45 ++++++++++++++++-------------- lightrag/kg/mongo_impl.py | 28 +++++++------------ lightrag/kg/neo4j_impl.py | 38 ++++++++++++------------- lightrag/kg/networkx_impl.py | 30 ++++++++++---------- lightrag/kg/oracle_impl.py | 30 ++++++++++---------- lightrag/kg/postgres_impl.py | 46 +++++++++++++++--------------- lightrag/kg/tidb_impl.py | 23 +++++++++------ 9 files changed, 164 insertions(+), 168 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 8efbe8a2..3cc7646d 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,22 +92,20 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + """Query the vector storage and retrieve top_k results.""" raise NotImplementedError async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Use 'content' field from value for embedding, use key as id. 
- If embedding_func is None, use 'embedding' field from value - """ + """Insert or update vectors in the storage.""" raise NotImplementedError async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" + """Delete a single entity by its name.""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" + """Delete relations for a given entity.""" raise NotImplementedError @@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace): embedding_func: EmbeddingFunc | None = None async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get value by id""" raise NotImplementedError async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get values by ids""" raise NotImplementedError async def filter_keys(self, keys: set[str]) -> set[str]: @@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace): raise NotImplementedError async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Upsert data""" raise NotImplementedError async def drop(self) -> None: + """Drop the storage""" raise NotImplementedError @@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace): """Check if a node exists in the graph.""" async def has_node(self, node_id: str) -> bool: + """Check if an edge exists in the graph.""" raise NotImplementedError - """Check if an edge exists in the graph.""" - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """Get the degree of a node.""" raise NotImplementedError - """Get the degree of a node.""" - async def node_degree(self, node_id: str) -> int: + """Get the degree of an edge.""" raise NotImplementedError - """Get the degree of an edge.""" - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get a node by its id.""" raise NotImplementedError - """Get a node by its id.""" - async def get_node(self, node_id: str) -> dict[str, str] | None: + 
"""Get an edge by its source and target node ids.""" raise NotImplementedError - """Get an edge by its source and target node ids.""" - async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get all edges connected to a node.""" raise NotImplementedError - """Get all edges connected to a node.""" async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Upsert a node into the graph.""" raise NotImplementedError - """Upsert a node into the graph.""" - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """Upsert an edge into the graph.""" raise NotImplementedError - """Upsert an edge into the graph.""" - async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: + """Delete a node from the graph.""" raise NotImplementedError - """Delete a node from the graph.""" - async def delete_node(self, node_id: str) -> None: + """Embed nodes using an algorithm.""" raise NotImplementedError - """Embed nodes using an algorithm.""" - async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: + """Get all labels in the graph.""" raise NotImplementedError("Node embedding is not used in lightrag.") - """Get all labels in the graph.""" - async def get_all_labels(self) -> list[str]: + """Get a knowledge graph of a node.""" raise NotImplementedError - """Get a knowledge graph of a node.""" - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + """Retrieve a subgraph of the knowledge graph starting from a given node.""" raise NotImplementedError diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index a6857f22..a64e4785 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -5,7 +5,8 @@ import os import sys from contextlib import 
asynccontextmanager from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union +import numpy as np import pipmaster as pm if not pm.is_installed("psycopg-pool"): @@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"): pm.install("asyncpg") +from lightrag.types import KnowledgeGraph import psycopg from psycopg.rows import namedtuple_row from psycopg_pool import AsyncConnectionPool, PoolTimeout @@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage): ) return single_result["edge_exists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name_label = node_id.strip('"') query = """ MATCH (n:`{label}`) RETURN n @@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ + ) -> dict[str, str] | None: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage): ) return result - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label. 
:return: List of dictionaries containing edge information @@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((AGEQueryException,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the AGE database. @@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage): retry=retry_if_exception_type((AGEQueryException,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. @@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage): yield connection finally: await self._driver.putconn(connection) + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index f38fd00a..77c627b6 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -3,7 +3,9 @@ import inspect import json import os from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List + +import numpy as np from gremlin_python.driver import client, serializer from gremlin_python.driver.aiohttp.transport import AiohttpTransport @@ -15,6 +17,7 @@ from tenacity import ( wait_exponential, ) +from lightrag.types import KnowledgeGraph from lightrag.utils import logger from ..base import 
BaseGraphStorage @@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage): return result[0]["has_edge"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name = GremlinStorage._fix_name(node_id) query = f"""g .V().has('graph', {self.graph_name}) @@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given names - - Args: - source_node_id (str): Name of the source nodes - target_node_id (str): Name of the target nodes - - Returns: - dict|None: Dict of found edge properties, or None if not found - """ + ) -> dict[str, str] | None: entity_name_source = GremlinStorage._fix_name(source_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id) query = f"""g @@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage): ) return edge_properties - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: - """ - Retrieves all edges (relationships) for a particular node identified by its name. - :return: List of tuples containing edge sources and targets - """ + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_name = GremlinStorage._fix_name(source_node_id) query = f"""g .E() @@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((GremlinServerError,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Gremlin graph. 
@@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage): retry=retry_if_exception_type((GremlinServerError,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their names. @@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 44820ecf..ce15fe29 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"): if not pm.is_installed("motor"): pm.install("motor") -from typing import Any, List, Tuple, Union +from typing import Any, List, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient from pymongo.operations import SearchIndexModel @@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """ Return the full node document (including "edges"), or None if missing. 
""" @@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Return the first edge dict from source_node_id to target_node_id if it exists. - Uses a single-hop $graphLookup as demonstration, though a direct find is simpler. - """ + ) -> dict[str, str] | None: pipeline = [ {"$match": {"_id": source_node_id}}, { @@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage): return e return None - async def get_node_edges( - self, source_node_id: str - ) -> Union[List[Tuple[str, str]], None]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. @@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def upsert_node(self, node_id: str, node_data: dict): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Insert or update a node document. If new, create an empty edges array. """ @@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage): await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge from source_node_id -> target_node_id with optional 'relation'. If an edge with the same target exists, we remove it and re-insert with updated data. @@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def delete_node(self, node_id: str): + async def delete_node(self, node_id: str) -> None: """ 1) Remove node's doc entirely. 
2) Remove inbound edges from any doc that references node_id. @@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: """ Placeholder for demonstration, raises NotImplementedError. """ @@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 15525375..f27a9645 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,7 +3,8 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, Union, Tuple, List, Dict +from typing import Any, List, Dict +import numpy as np import pipmaster as pm import configparser @@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage): ) return single_result["edgeExists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier. Args: @@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """Find edge between two nodes identified by their labels. 
+ ) -> dict[str, str] | None: - Args: - source_node_id (str): Label of the source node - target_node_id (str): Label of the target node - - Returns: - dict: Edge properties if found, with at least {"weight": 0.0} - None: If error occurs - """ try: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage): # Return default edge properties on error return {"weight": 0.0, "source_id": None, "target_id": None} - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_label = source_node_id.strip('"') """ @@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage): ) ), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Neo4j database. @@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage): ), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. 
@@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage): await traverse(label, 0) return result - async def get_all_labels(self) -> List[str]: + async def get_all_labels(self) -> list[str]: """ Get all existing node labels in the database Returns: @@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage): async for record in result: labels.append(record["label"]) return labels + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index bb84cf82..254bb0ed 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -51,11 +51,12 @@ Usage: import html import os from dataclasses import dataclass -from typing import Any, Union, cast +from typing import Any, cast import networkx as nx import numpy as np +from lightrag.types import KnowledgeGraph from lightrag.utils import ( logger, ) @@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage): "node2vec": self._node2vec_embed, } - async def index_done_callback(self): + async def index_done_callback(self) -> None: NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: @@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) - async def 
get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: return self._graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: @@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: return self._graph.edges.get((source_node_id, target_node_id)) - async def get_node_edges(self, source_node_id: str): + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id)) return None - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def delete_node(self, node_id: str): - """ - Delete a node from the graph based on the specified node_id. 
- - :param node_id: The node_id to delete - """ + async def delete_node(self, node_id: str) -> None: if self._graph.has_node(node_id): self._graph.remove_node(node_id) logger.info(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage): for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 95d888b3..360a4847 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"): pm.install("oracledb") +from lightrag.types import KnowledgeGraph import oracledb from ..base import ( @@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage): #################### insert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - """插入或更新节点""" - # print("go into upsert node method") + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] @@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: """插入或更新边""" # print("go into 
upsert edge method") source_name = source_node_id @@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage): await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - """为节点生成向量""" + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - async def index_done_callback(self): + async def index_done_callback(self) -> None: """写入graphhml图文件""" logger.info( "Node and edge data had been saved into oracle db already, so nothing to do here!" @@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - """根据源和目标节点id检查边是否存在""" SQL = SQL_TEMPLATES["has_edge"] params = { "workspace": self.db.workspace, @@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def node_degree(self, node_id: str) -> int: - """根据节点id获取节点的度""" SQL = SQL_TEMPLATES["node_degree"] params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) @@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge degree",degree) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """根据节点id获取节点数据""" SQL = SQL_TEMPLATES["get_node"] params = {"workspace": self.db.workspace, "node_id": node_id} @@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """根据源和目标节点id获取边""" + ) -> 
dict[str, str] | None: SQL = SQL_TEMPLATES["get_edge"] params = { "workspace": self.db.workspace, @@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) return None - async def get_node_edges(self, source_node_id: str): - """根据节点id获取节点的所有边""" + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if await self.has_node(source_node_id): SQL = SQL_TEMPLATES["get_node_edges"] params = {"workspace": self.db.workspace, "source_node_id": source_node_id} @@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage): if res: return res + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 98f9c495..47336190 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,11 +4,13 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Union import numpy as np import pipmaster as pm +from lightrag.types import KnowledgeGraph + if not pm.is_installed("asyncpg"): pm.install("asyncpg") @@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage): ) return single_result["edge_exists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: label = PGGraphStorage._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) @@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> 
Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_id (str): Label of the source nodes - target_node_id (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ + ) -> dict[str, str] | None: src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) @@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage): ) return result - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label. :return: List of dictionaries containing edge information @@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((PGGraphQueryException,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): - """ - Upsert a node in the AGE database. - - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: label = PGGraphStorage._encode_graph_label(node_id.strip('"')) properties = node_data @@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage): retry=retry_if_exception_type((PGGraphQueryException,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. 
@@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 6f388e7f..44c0d9e7 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"): if not pm.is_installed("sqlalchemy"): pm.install("sqlalchemy") +from lightrag.types import KnowledgeGraph from sqlalchemy import create_engine, text from tqdm import tqdm @@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage): self._max_batch_size = self.global_config["embedding_batch_num"] #################### upsert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] @@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: source_name = source_node_id target_name = target_node_id weight = edge_data["weight"] @@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage): } await self.db.execute(merge_sql, data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in 
self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage): degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: sql = SQL_TEMPLATES["get_node"] param = {"name": node_id, "workspace": self.db.workspace} return await self.db.query(sql, param) async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: sql = SQL_TEMPLATES["get_edge"] param = { "source_name": source_node_id, @@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage): } return await self.db.query(sql, param) - async def get_node_edges( - self, source_node_id: str - ) -> Union[list[tuple[str, str]], None]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: sql = SQL_TEMPLATES["get_node_edges"] param = {"source_name": source_node_id, "workspace": self.db.workspace} res = await self.db.query(sql, param, multirows=True) @@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage): else: return [] + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", From 931c31fa8c2d893572e3c787c1f3fbc9f683eef2 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:55:30 +0100 Subject: [PATCH 04/17] cleaned code --- lightrag/base.py | 6 ++++-- lightrag/kg/age_impl.py | 18 +++++++++++------- lightrag/kg/chroma_impl.py | 3 +-- lightrag/kg/gremlin_impl.py | 16 +++++++++------- lightrag/kg/json_kv_impl.py | 2 
+- lightrag/kg/milvus_impl.py | 4 ++-- lightrag/kg/mongo_impl.py | 17 +++++++++++------ lightrag/kg/neo4j_impl.py | 9 +++++---- lightrag/kg/networkx_impl.py | 12 ++++++++---- lightrag/kg/oracle_impl.py | 15 +++++++++++---- lightrag/kg/postgres_impl.py | 17 +++++++++++------ lightrag/kg/qdrant_impl.py | 4 ++-- lightrag/kg/redis_impl.py | 8 ++++---- lightrag/kg/tidb_impl.py | 21 +++++++++++++-------- 14 files changed, 93 insertions(+), 59 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3cc7646d..8e3a7ecf 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,6 +92,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results.""" raise NotImplementedError @@ -165,7 +166,6 @@ class BaseGraphStorage(StorageNameSpace): """Get all edges connected to a node.""" raise NotImplementedError - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Upsert a node into the graph.""" raise NotImplementedError @@ -194,7 +194,9 @@ class BaseGraphStorage(StorageNameSpace): """Get a knowledge graph of a node.""" raise NotImplementedError - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """Retrieve a subgraph of the knowledge graph starting from a given node.""" raise NotImplementedError diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index a64e4785..37ab57d7 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -614,12 +614,16 @@ class AGEStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + + async 
def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError - - async def get_all_labels(self) -> list[str]: + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index e32346f9..7e325abd 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -183,7 +183,6 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB query: {str(e)}") raise - async def index_done_callback(self) -> None: # ChromaDB handles persistence automatically pass @@ -194,4 +193,4 @@ class ChromaVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 77c627b6..48bf77c8 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -389,14 +389,16 @@ class GremlinStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str - ) -> tuple[np.ndarray[Any, Any], list[str]]: + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError - - async def get_all_labels(self) -> list[str]: + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 
5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 5683801f..7d51ae93 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage): ] async def filter_keys(self, keys: set[str]) -> set[str]: - return set(data) - set(self._data.keys()) + return set(keys) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index d67f03b1..703229c8 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -127,11 +127,11 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index ce15fe29..463e24d2 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -68,9 +68,9 @@ class MongoKVStorage(BaseKVStorage): return await cursor.to_list() async def filter_keys(self, keys: set[str]) -> set[str]: - cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) existing_ids = {str(x["_id"]) async for x in cursor} - return data - existing_ids + return keys - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -109,7 +109,7 @@ class MongoKVStorage(BaseKVStorage): async def index_done_callback(self) -> 
None: pass - + async def drop(self) -> None: """Drop the collection""" await self._data.drop() @@ -570,7 +570,9 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: """ Placeholder for demonstration, raises NotImplementedError. """ @@ -600,7 +602,9 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -918,7 +922,7 @@ class MongoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError @@ -927,6 +931,7 @@ class MongoVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): """Check if the collection exists. 
if not, create it.""" client = MongoClient(uri) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f27a9645..d8e8faa8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -254,7 +254,6 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - try: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -436,7 +435,9 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -620,8 +621,8 @@ class Neo4JStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 254bb0ed..109c5827 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -186,7 +186,9 @@ class NetworkXStorage(BaseGraphStorage): else: logger.warning(f"Node {node_id} not found in the graph for deletion.") - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -225,6 +227,8 @@ class NetworkXStorage(BaseGraphStorage): async def get_all_labels(self) -> list[str]: raise 
NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 360a4847..74268a67 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -318,6 +318,7 @@ class OracleKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class OracleVectorDBStorage(BaseVectorStorage): # db instance must be injected before use @@ -368,6 +369,7 @@ class OracleVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + @dataclass class OracleGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -452,7 +454,9 @@ class OracleGraphStorage(BaseGraphStorage): await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -593,13 +597,16 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def get_all_labels(self) -> list[str]: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", 
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 47336190..77a42ad1 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -299,6 +299,7 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class PGVectorStorage(BaseVectorStorage): # db instance must be injected before use @@ -428,6 +429,7 @@ class PGVectorStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + @dataclass class PGDocStatusStorage(DocStatusStorage): # db instance must be injected before use @@ -1040,18 +1042,21 @@ class PGGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError - - async def get_all_labels(self) -> list[str]: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + + async def get_all_labels(self) -> list[str]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + + NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 18a50082..eb9582e6 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -147,11 +147,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning 
metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index f735c72a..71e39c5c 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -41,12 +41,12 @@ class RedisKVStorage(BaseKVStorage): async def filter_keys(self, keys: set[str]) -> set[str]: pipe = self._redis.pipeline() - for key in data: + for key in keys: pipe.exists(f"{self.namespace}:{key}") results = await pipe.execute() - existing_ids = {data[i] for i, exists in enumerate(results) if exists} - return set(data) - existing_ids + existing_ids = {keys[i] for i, exists in enumerate(results) if exists} + return set(keys) - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pipe = self._redis.pipeline() @@ -63,4 +63,4 @@ class RedisKVStorage(BaseKVStorage): await self._redis.delete(*keys) async def index_done_callback(self) -> None: - pass \ No newline at end of file + pass diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 44c0d9e7..27850d81 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -127,7 +127,6 @@ class TiDBKVStorage(BaseKVStorage): return await self.db.query(SQL, multirows=True) async def filter_keys(self, keys: set[str]) -> set[str]: - """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), id_field=namespace_to_id(self.namespace), @@ -211,6 +210,7 @@ class TiDBKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class TiDBVectorDBStorage(BaseVectorStorage): # db instance must be injected before use @@ -335,7 +335,6 @@ class TiDBVectorDBStorage(BaseVectorStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) - async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError @@ -343,7 
+342,8 @@ class TiDBVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" raise NotImplementedError - + + @dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -420,7 +420,9 @@ class TiDBGraphStorage(BaseGraphStorage): } await self.db.execute(merge_sql, data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -481,13 +483,16 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def get_all_labels(self) -> list[str]: - raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", From 7848a38a45ad116e6ae53dd784b55a1cb5c3bbbd Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 14:11:19 +0100 Subject: [PATCH 05/17] added all abstractmethod --- lightrag/base.py | 72 +++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 8e3a7ecf..fc4702d4 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod import os from dataclasses import dataclass, field from enum import Enum @@ -79,126 +80,126 @@ class QueryParam: 
@dataclass -class StorageNameSpace: +class StorageNameSpace(ABC): namespace: str global_config: dict[str, Any] + @abstractmethod async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" - pass @dataclass -class BaseVectorStorage(StorageNameSpace): +class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) + @abstractmethod async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results.""" - raise NotImplementedError + @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Insert or update vectors in the storage.""" - raise NotImplementedError + @abstractmethod async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name.""" - raise NotImplementedError + @abstractmethod async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity.""" - raise NotImplementedError @dataclass -class BaseKVStorage(StorageNameSpace): - embedding_func: EmbeddingFunc | None = None +class BaseKVStorage(StorageNameSpace, ABC): + embedding_func: EmbeddingFunc + @abstractmethod async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get value by id""" - raise NotImplementedError + @abstractmethod async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get values by ids""" - raise NotImplementedError + @abstractmethod async def filter_keys(self, keys: set[str]) -> set[str]: """Return un-exist keys""" - raise NotImplementedError + @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Upsert data""" - raise NotImplementedError + @abstractmethod async def drop(self) -> None: """Drop the storage""" - raise NotImplementedError @dataclass -class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc | None = None +class BaseGraphStorage(StorageNameSpace, ABC): + 
embedding_func: EmbeddingFunc """Check if a node exists in the graph.""" + @abstractmethod async def has_node(self, node_id: str) -> bool: """Check if an edge exists in the graph.""" - raise NotImplementedError + @abstractmethod async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """Get the degree of a node.""" - raise NotImplementedError + @abstractmethod async def node_degree(self, node_id: str) -> int: """Get the degree of an edge.""" - raise NotImplementedError + @abstractmethod async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get a node by its id.""" - raise NotImplementedError + @abstractmethod async def get_node(self, node_id: str) -> dict[str, str] | None: """Get an edge by its source and target node ids.""" - raise NotImplementedError + @abstractmethod async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: """Get all edges connected to a node.""" - raise NotImplementedError + @abstractmethod async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Upsert a node into the graph.""" - raise NotImplementedError + @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Upsert an edge into the graph.""" - raise NotImplementedError + @abstractmethod async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """Delete a node from the graph.""" - raise NotImplementedError + @abstractmethod async def delete_node(self, node_id: str) -> None: """Embed nodes using an algorithm.""" - raise NotImplementedError + @abstractmethod async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: """Get all labels in the graph.""" - raise NotImplementedError("Node embedding is not used in lightrag.") + @abstractmethod async def get_all_labels(self) -> list[str]: """Get a knowledge graph of a node.""" - raise NotImplementedError + @abstractmethod async def 
get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """Retrieve a subgraph of the knowledge graph starting from a given node.""" - raise NotImplementedError class DocStatus(str, Enum): @@ -234,29 +235,32 @@ class DocProcessingStatus: """Additional metadata""" -class DocStatusStorage(BaseKVStorage): +@dataclass +class DocStatusStorage(BaseKVStorage, ABC): """Base class for document status storage""" + @abstractmethod async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" - raise NotImplementedError + @abstractmethod async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: """Get all failed documents""" - raise NotImplementedError + @abstractmethod async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: """Get all pending documents""" raise NotImplementedError + @abstractmethod async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: """Get all processing documents""" - raise NotImplementedError + @abstractmethod async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: """Get all procesed documents""" - raise NotImplementedError + @abstractmethod async def update_doc_status(self, data: dict[str, Any]) -> None: """Updates the status of a document. 
By default, it calls upsert.""" await self.upsert(data) From 3fef8201c6260068bca931fab3c05cfcb6941f40 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 14:38:09 +0100 Subject: [PATCH 06/17] added final, required methods and cleaned import --- lightrag/base.py | 1 + lightrag/kg/age_impl.py | 31 ++++++------ lightrag/kg/chroma_impl.py | 16 ++++-- lightrag/kg/faiss_impl.py | 14 ++++-- lightrag/kg/gremlin_impl.py | 20 +++++--- lightrag/kg/json_doc_status_impl.py | 76 +++++------------------------ lightrag/kg/json_kv_impl.py | 2 +- lightrag/kg/milvus_impl.py | 16 +++--- lightrag/kg/mongo_impl.py | 39 +++++++++------ lightrag/kg/nano_vector_db_impl.py | 67 ++++--------------------- lightrag/kg/neo4j_impl.py | 30 +++++++----- lightrag/kg/networkx_impl.py | 62 ++++------------------- lightrag/kg/oracle_impl.py | 34 ++++++------- lightrag/kg/postgres_impl.py | 60 +++++++++++------------ lightrag/kg/qdrant_impl.py | 18 ++++--- lightrag/kg/tidb_impl.py | 39 ++++++++------- 16 files changed, 209 insertions(+), 316 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index fc4702d4..798e3176 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,6 +92,7 @@ class StorageNameSpace(ABC): @dataclass class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc + cosine_better_than_threshold: float meta_fields: set[str] = field(default_factory=set) @abstractmethod diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 37ab57d7..24f70de9 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -5,21 +5,11 @@ import os import sys from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") -if not 
pm.is_installed("asyncpg"): - pm.install("asyncpg") - from lightrag.types import KnowledgeGraph -import psycopg -from psycopg.rows import namedtuple_row -from psycopg_pool import AsyncConnectionPool, PoolTimeout + from tenacity import ( retry, retry_if_exception_type, @@ -37,6 +27,16 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +try: + import psycopg + from psycopg.rows import namedtuple_row + from psycopg_pool import AsyncConnectionPool, PoolTimeout +except ImportError as e: + raise ImportError( + "psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed." + ) from e + + class AGEQueryException(Exception): """Exception for the AGE queries.""" @@ -55,6 +55,7 @@ class AGEQueryException(Exception): return self.details +@final @dataclass class AGEStorage(BaseGraphStorage): @staticmethod @@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") - @staticmethod def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: """ @@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: raise NotImplementedError + + async def index_done_callback(self) -> None: + pass diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 7e325abd..f2d2293f 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,19 +1,25 @@ import asyncio from dataclasses import dataclass -from typing import Any +from typing import Any, final import numpy as np -from chromadb import HttpClient, PersistentClient -from chromadb.config import Settings + from lightrag.base import BaseVectorStorage from lightrag.utils import logger +try: + from chromadb import HttpClient, PersistentClient + from chromadb.config import Settings +except ImportError as e: + raise ImportError( + "chromadb library 
is not installed. Please install it to proceed." + ) from e + +@final @dataclass class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = None - def __post_init__(self): try: config = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 3027f3f0..e2c06afe 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -1,8 +1,8 @@ import os import time import asyncio -from typing import Any -import faiss +from typing import Any, final + import json import numpy as np from tqdm.asyncio import tqdm as tqdm_async @@ -16,7 +16,15 @@ from lightrag.base import ( BaseVectorStorage, ) +try: + import faiss +except ImportError as e: + raise ImportError( + "faiss library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class FaissVectorDBStorage(BaseVectorStorage): """ @@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage): Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. 
""" - cosine_better_than_threshold: float = None - def __post_init__(self): # Grab config values if available config = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 48bf77c8..4038be23 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -3,13 +3,11 @@ import inspect import json import os from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, final import numpy as np -from gremlin_python.driver import client, serializer -from gremlin_python.driver.aiohttp.transport import AiohttpTransport -from gremlin_python.driver.protocol import GremlinServerError + from tenacity import ( retry, retry_if_exception_type, @@ -22,7 +20,17 @@ from lightrag.utils import logger from ..base import BaseGraphStorage +try: + from gremlin_python.driver import client, serializer + from gremlin_python.driver.aiohttp.transport import AiohttpTransport + from gremlin_python.driver.protocol import GremlinServerError +except ImportError as e: + raise ImportError( + "gremlin library is not installed. Please install it to proceed." 
+ ) from e + +@final @dataclass class GremlinStorage(BaseGraphStorage): @staticmethod @@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage): if self._driver: self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass @staticmethod def _to_value_map(value: Any) -> str: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index fad03acc..b96a744c 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,56 +1,6 @@ -""" -JsonDocStatus Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - from dataclasses import dataclass import os -from typing import Any, Union +from typing import Any, Union, final from lightrag.base import ( DocProcessingStatus, @@ -64,6 +14,7 @@ from lightrag.utils import ( ) +@final @dataclass class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" @@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Loaded document status storage with {len(self._data)} records") - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return set(data) - set(self._data.keys()) + return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] @@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage): return counts async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - 
"""Get all pending documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processed documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() if v["status"] == DocStatus.PROCESSING } - async def index_done_callback(self): - """Save data to file after indexing""" + async def index_done_callback(self) -> None: write_json(self._data, self._file_name) async def upsert(self, data: dict[str, Any]) -> None: - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ self._data.update(data) await self.index_done_callback() @@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage): return self._data.get(id) async def delete(self, doc_ids: list[str]): - """Delete document status by IDs""" for doc_id in doc_ids: self._data.pop(doc_id, None) await self.index_done_callback() + + async def drop(self) -> None: + raise NotImplementedError + + async def update_doc_status(self, data: dict[str, Any]) -> None: + raise NotImplementedError diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 7d51ae93..779c52a9 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage): self._lock = asyncio.Lock() logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def index_done_callback(self): + async def index_done_callback(self) -> None: write_json(self._data, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: diff --git a/lightrag/kg/milvus_impl.py 
b/lightrag/kg/milvus_impl.py index 703229c8..1288df07 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1,27 +1,29 @@ import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage -import pipmaster as pm + import configparser -if not pm.is_installed("pymilvus"): - pm.install("pymilvus") -from pymilvus import MilvusClient +try: + from pymilvus import MilvusClient +except ImportError: + raise ImportError( + "pymilvus library is not installed. Please install it to proceed." + ) config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MilvusVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - @staticmethod def create_collection_if_not_exist( client: MilvusClient, collection_name: str, **kwargs diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 463e24d2..f44332bf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,22 +1,11 @@ import os from dataclasses import dataclass import numpy as np -import pipmaster as pm import configparser from tqdm.asyncio import tqdm as tqdm_async import asyncio -if not pm.is_installed("pymongo"): - pm.install("pymongo") - -if not pm.is_installed("motor"): - pm.install("motor") - -from typing import Any, List, Union -from motor.motor_asyncio import AsyncIOMotorClient -from pymongo import MongoClient -from pymongo.operations import SearchIndexModel -from pymongo.errors import PyMongoError +from typing import Any, List, Union, final from ..base import ( BaseGraphStorage, @@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace from ..utils import logger from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +try: + from motor.motor_asyncio import AsyncIOMotorClient + from pymongo import 
MongoClient + from pymongo.operations import SearchIndexModel + from pymongo.errors import PyMongoError +except ImportError as e: + raise ImportError( + "motor, pymongo library is not installed. Please install it to proceed." + ) from e + config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): @@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage): await self._data.drop() +@final @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): @@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage): """Get all procesed documents""" return await self.get_docs_by_status(DocStatus.PROCESSED) + async def index_done_callback(self) -> None: + # Implement the method here + pass + async def update_doc_status(self, data: dict[str, Any]) -> None: + raise NotImplementedError + + +@final @dataclass class MongoGraphStorage(BaseGraphStorage): """ @@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage): return result + async def index_done_callback(self) -> None: + pass + +@final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 8b931424..4ab98fe6 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -1,65 +1,10 @@ -""" -NanoVectorDB Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. 
- -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np -import pipmaster as pm -if not pm.is_installed("nano-vectordb"): - pm.install("nano-vectordb") - -from nano_vectordb import NanoVectorDB import time from lightrag.utils import ( @@ -71,11 +16,17 @@ from lightrag.base import ( BaseVectorStorage, ) +try: + from nano_vectordb import NanoVectorDB +except ImportError as e: + raise ImportError( + "nano-vectordb library is not installed. Please install it to proceed." 
+ ) from e + +@final @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): # Initialize lock only for file operations self._save_lock = asyncio.Lock() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index d8e8faa8..8d078af0 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,21 +3,11 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, List, Dict +from typing import Any, List, Dict, final import numpy as np -import pipmaster as pm import configparser -if not pm.is_installed("neo4j"): - pm.install("neo4j") -from neo4j import ( - AsyncGraphDatabase, - exceptions as neo4jExceptions, - AsyncDriver, - AsyncManagedTransaction, - GraphDatabase, -) from tenacity import ( retry, stop_after_attempt, @@ -29,11 +19,25 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +try: + from neo4j import ( + AsyncGraphDatabase, + exceptions as neo4jExceptions, + AsyncDriver, + AsyncManagedTransaction, + GraphDatabase, + ) +except ImportError as e: + raise ImportError( + "neo4j library is not installed. Please install it to proceed." 
+ ) from e + config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass async def _label_exists(self, label: str) -> bool: """Check if a label exists in the Neo4j database.""" diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 109c5827..f98a8bbb 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,58 +1,8 @@ -""" -NetworkX Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import html import os from dataclasses import dataclass -from typing import Any, cast -import networkx as nx +from typing import Any, cast, final + import numpy as np @@ -65,7 +15,15 @@ from lightrag.base import ( BaseGraphStorage, ) +try: + import networkx as nx +except ImportError as e: + raise ImportError( + "networkx library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 74268a67..aec4ada4 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -4,17 +4,11 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("oracledb"): - pm.install("oracledb") - from lightrag.types import KnowledgeGraph -import oracledb from ..base import ( BaseGraphStorage, @@ -24,6 +18,14 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger +try: + import oracledb + +except ImportError as e: + raise ImportError( + "oracledb library is not installed. Please install it to proceed." 
+ ) from e + class OracleDB: def __init__(self, config, **kwargs): @@ -170,6 +172,7 @@ class OracleDB: raise +@final @dataclass class OracleKVStorage(BaseKVStorage): # db instance must be injected before use @@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage): raise NotImplementedError +@final @dataclass class OracleVectorDBStorage(BaseVectorStorage): - # db instance must be injected before use - # db: OracleDB - cosine_better_than_threshold: float = None - def __post_init__(self): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") @@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage): async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pass - async def index_done_callback(self): + async def index_done_callback(self) -> None: pass #################### query method ############### @@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage): raise NotImplementedError +@final @dataclass class OracleGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: OracleDB - def __post_init__(self): - """从graphml文件加载图""" self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage): return embeddings, nodes_ids async def index_done_callback(self) -> None: - """写入graphhml图文件""" - logger.info( - "Node and edge data had been saved into oracle db already, so nothing to do here!" 
- ) + pass #################### query method ################# async def has_node(self, node_id: str) -> bool: diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 77a42ad1..c63547ce 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,26 +4,19 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, final import numpy as np -import pipmaster as pm from lightrag.types import KnowledgeGraph -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - import sys - -import asyncpg from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) -from tqdm.asyncio import tqdm as tqdm_async from ..base import ( BaseGraphStorage, @@ -41,6 +34,15 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +try: + import asyncpg + from tqdm.asyncio import tqdm as tqdm_async + +except ImportError as e: + raise ImportError( + "asyncpg, tqdm_async library is not installed. Please install it to proceed." 
+ ) from e + class PostgreSQLDB: def __init__(self, config, **kwargs): @@ -177,6 +179,7 @@ class PostgreSQLDB: pass +@final @dataclass class PGKVStorage(BaseKVStorage): # db instance must be injected before use @@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into postgresql db!") + pass async def drop(self) -> None: raise NotImplementedError +@final @dataclass class PGVectorStorage(BaseVectorStorage): - # db instance must be injected before use - # db: PostgreSQLDB - cosine_better_than_threshold: float = None - def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] config = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) async def index_done_callback(self) -> None: - logger.info("vector data had been saved into postgresql db!") + pass #################### query method ############### async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: @@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage): raise NotImplementedError +@final @dataclass class PGDocStatusStorage(DocStatusStorage): # db instance must be injected before use # db: PostgreSQLDB - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" - keys = ",".join([f"'{_id}'" for _id in data]) + keys = ",".join([f"'{_id}'" for _id in keys]) sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" result = await self.db.query(sql, multirows=True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. 
if result is None: - return set(data) + return set(keys) else: existed = set([element["id"] for element in result]) - return set(data) - existed + return set(keys) - existed async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" @@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage): updated_at=result[0]["updated_at"], ) + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + raise NotImplementedError + async def get_status_counts(self) -> Dict[str, int]: """Get counts of documents in each status""" sql = """SELECT status as "status", COUNT(1) as "count" @@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage): """Get all procesed documents""" return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self): - """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" - logger.info("Doc status had been saved into postgresql db!") + async def index_done_callback(self) -> None: + pass async def upsert(self, data: dict[str, dict]): """Update or insert document status @@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage): } await self.db.execute(sql, _data) + async def drop(self) -> None: + raise NotImplementedError + class PGGraphQueryException(Exception): """Exception for the AGE queries.""" @@ -593,11 +595,9 @@ class PGGraphQueryException(Exception): return self.details +@final @dataclass class PGGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: PostgreSQLDB - @staticmethod def load_nx_graph(file_name): print("no preloading of graph with AGE in production") @@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage): "node2vec": self._node2vec_embed, } - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass @staticmethod def 
_record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index eb9582e6..1d4a0ca1 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -8,17 +8,20 @@ import hashlib import uuid from ..utils import logger from ..base import BaseVectorStorage -import pipmaster as pm import configparser -if not pm.is_installed("qdrant_client"): - pm.install("qdrant_client") - -from qdrant_client import QdrantClient, models config = configparser.ConfigParser() config.read("config.ini", "utf-8") +try: + from qdrant_client import QdrantClient, models + +except ImportError as e: + raise ImportError( + "qdrant_client library is not installed. Please install it to proceed." + ) from e + def compute_mdhash_id_for_qdrant( content: str, prefix: str = "", style: str = "simple" @@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant( raise ValueError("Invalid style. 
Choose from 'simple', 'hyphenated', or 'urn'.") +@final @dataclass class QdrantVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - @staticmethod def create_collection_if_not_exist( client: QdrantClient, collection_name: str, **kwargs diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 27850d81..69a6da2a 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -1,24 +1,26 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("pymysql"): - pm.install("pymysql") -if not pm.is_installed("sqlalchemy"): - pm.install("sqlalchemy") from lightrag.types import KnowledgeGraph -from sqlalchemy import create_engine, text + from tqdm import tqdm from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace from ..utils import logger +try: + from sqlalchemy import create_engine, text + +except ImportError as e: + raise ImportError( + "pymysql, sqlalchemy library is not installed. Please install it to proceed." 
+ ) from e + class TiDB: def __init__(self, config, **kwargs): @@ -100,6 +102,7 @@ class TiDB: raise +@final @dataclass class TiDBKVStorage(BaseKVStorage): # db instance must be injected before use @@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage): await self.db.execute(merge_sql, data) return left_data - async def index_done_callback(self): - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into TiDB db!") + async def index_done_callback(self) -> None: + pass async def drop(self) -> None: raise NotImplementedError +@final @dataclass class TiDBVectorDBStorage(BaseVectorStorage): - # db instance must be injected before use - # db: TiDB - cosine_better_than_threshold: float = None - def __post_init__(self): self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" @@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + async def index_done_callback(self) -> None: + raise NotImplementedError + +@final @dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage): else: return [] + async def index_done_callback(self) -> None: + pass + async def delete_node(self, node_id: str) -> None: raise NotImplementedError From a0844bca2837861b2bdd9670b6ab5970d8d2334d Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 14:45:45 +0100 Subject: [PATCH 07/17] cleaned import --- lightrag/kg/age_impl.py | 6 +++--- lightrag/kg/chroma_impl.py | 2 +- lightrag/kg/faiss_impl.py | 2 +- lightrag/kg/gremlin_impl.py | 2 +- lightrag/kg/mongo_impl.py | 3 +-- lightrag/kg/nano_vector_db_impl.py | 2 +- lightrag/kg/neo4j_impl.py | 3 +-- lightrag/kg/networkx_impl.py | 2 +- lightrag/kg/oracle_impl.py | 2 +- 
lightrag/kg/postgres_impl.py | 2 +- lightrag/kg/qdrant_impl.py | 6 +++--- lightrag/kg/tidb_impl.py | 2 +- 12 files changed, 16 insertions(+), 18 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 24f70de9..f9499376 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -31,10 +31,10 @@ try: import psycopg from psycopg.rows import namedtuple_row from psycopg_pool import AsyncConnectionPool, PoolTimeout -except ImportError as e: +except ImportError: raise ImportError( - "psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed." - ) from e + "`psycopg-pool, psycopg[binary,pool], asyncpg` library is not installed. Please install it via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`." + ) class AGEQueryException(Exception): diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index f2d2293f..ecac2b62 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -11,7 +11,7 @@ try: from chromadb.config import Settings except ImportError as e: raise ImportError( - "chromadb library is not installed. Please install it to proceed." + "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`." ) from e diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index e2c06afe..0f455f1c 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -20,7 +20,7 @@ try: import faiss except ImportError as e: raise ImportError( - "faiss library is not installed. Please install it to proceed." + "`faiss` library is not installed. Please install it via pip: `pip install faiss`." 
) from e diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 4038be23..d95bb00b 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -26,7 +26,7 @@ try: from gremlin_python.driver.protocol import GremlinServerError except ImportError as e: raise ImportError( - "gremlin library is not installed. Please install it to proceed." + "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`." ) from e diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index f44332bf..330fa474 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -26,10 +26,9 @@ try: from pymongo.errors import PyMongoError except ImportError as e: raise ImportError( - "motor, pymongo library is not installed. Please install it to proceed." + "`motor, pymongo` library is not installed. Please install it via pip: `pip install motor pymongo`." ) from e - config = configparser.ConfigParser() config.read("config.ini", "utf-8") diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 4ab98fe6..fbd6a06a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -20,7 +20,7 @@ try: from nano_vectordb import NanoVectorDB except ImportError as e: raise ImportError( - "nano-vectordb library is not installed. Please install it to proceed." + "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`." ) from e diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 8d078af0..83edd299 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -29,10 +29,9 @@ try: ) except ImportError as e: raise ImportError( - "neo4j library is not installed. Please install it to proceed." + "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`." 
) from e - config = configparser.ConfigParser() config.read("config.ini", "utf-8") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index f98a8bbb..04bb3bd7 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -19,7 +19,7 @@ try: import networkx as nx except ImportError as e: raise ImportError( - "networkx library is not installed. Please install it to proceed." + "`networkx` library is not installed. Please install it via pip: `pip install networkx`." ) from e diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index aec4ada4..d2d10141 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -23,7 +23,7 @@ try: except ImportError as e: raise ImportError( - "oracledb library is not installed. Please install it to proceed." + "`oracledb` library is not installed. Please install it via pip: `pip install oracledb`." ) from e diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c63547ce..2b7996d6 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -40,7 +40,7 @@ try: except ImportError as e: raise ImportError( - "asyncpg, tqdm_async library is not installed. Please install it to proceed." + "`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`." ) from e diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 1d4a0ca1..124d48d9 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -17,10 +17,10 @@ config.read("config.ini", "utf-8") try: from qdrant_client import QdrantClient, models -except ImportError as e: +except ImportError: raise ImportError( - "qdrant_client library is not installed. Please install it to proceed." - ) from e + "`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`." 
+ ) def compute_mdhash_id_for_qdrant( diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 69a6da2a..003316d3 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -18,7 +18,7 @@ try: except ImportError as e: raise ImportError( - "pymysql, sqlalchemy library is not installed. Please install it to proceed." + "`pymysql, sqlalchemy` library is not installed. Please install it via pip: `pip install pymysql sqlalchemy`." ) from e From 9a5fbaaa5f67b97bbe28b46c5a4372cd4bf8c0cf Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 14:50:04 +0100 Subject: [PATCH 08/17] removed unused methods --- lightrag/base.py | 5 ----- lightrag/kg/json_doc_status_impl.py | 5 +---- lightrag/kg/postgres_impl.py | 31 +---------------------------- lightrag/lightrag.py | 4 ++-- 4 files changed, 4 insertions(+), 41 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 798e3176..ee13a11f 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -260,8 +260,3 @@ class DocStatusStorage(BaseKVStorage, ABC): @abstractmethod async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: """Get all procesed documents""" - - @abstractmethod - async def update_doc_status(self, data: dict[str, Any]) -> None: - """Updates the status of a document. 
By default, it calls upsert.""" - await self.upsert(data) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index b96a744c..15fbfcde 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -75,7 +75,7 @@ class JsonDocStatusStorage(DocStatusStorage): async def index_done_callback(self) -> None: write_json(self._data, self._file_name) - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: self._data.update(data) await self.index_done_callback() @@ -89,6 +89,3 @@ class JsonDocStatusStorage(DocStatusStorage): async def drop(self) -> None: raise NotImplementedError - - async def update_doc_status(self, data: dict[str, Any]) -> None: - raise NotImplementedError diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 2b7996d6..0dffa7d3 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -429,9 +429,6 @@ class PGVectorStorage(BaseVectorStorage): @final @dataclass class PGDocStatusStorage(DocStatusStorage): - # db instance must be injected before use - # db: PostgreSQLDB - async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" keys = ",".join([f"'{_id}'" for _id in keys]) @@ -516,7 +513,7 @@ class PGDocStatusStorage(DocStatusStorage): async def index_done_callback(self) -> None: pass - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Update or insert document status Args: @@ -547,32 +544,6 @@ class PGDocStatusStorage(DocStatusStorage): ) return data - async def update_doc_status(self, data: dict[str, dict]) -> None: - """ - Updates only the document status, chunk count, and updated timestamp. - - This method ensures that only relevant fields are updated instead of overwriting - the entire document record. 
If `updated_at` is not provided, the database will - automatically use the current timestamp. - """ - sql = """ - UPDATE LIGHTRAG_DOC_STATUS - SET status = $3, - chunks_count = $4, - updated_at = CURRENT_TIMESTAMP - WHERE workspace = $1 AND id = $2 - """ - for k, v in data.items(): - _data = { - "workspace": self.db.workspace, - "id": k, - "status": v["status"].value, # Convert Enum to string - "chunks_count": v.get( - "chunks_count", -1 - ), # Default to -1 if not provided - } - await self.db.execute(sql, _data) - async def drop(self) -> None: raise NotImplementedError diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 23c3df80..f4e9b770 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -779,7 +779,7 @@ class LightRAG: ] try: await asyncio.gather(*tasks) - await self.doc_status.update_doc_status( + await self.doc_status.upsert( { doc_status_id: { "status": DocStatus.PROCESSED, @@ -796,7 +796,7 @@ class LightRAG: except Exception as e: logger.error(f"Failed to process document {doc_id}: {str(e)}") - await self.doc_status.update_doc_status( + await self.doc_status.upsert( { doc_status_id: { "status": DocStatus.FAILED, From 0a8c94a1e009549cd53db2267771d8881836dedc Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 14:51:24 +0100 Subject: [PATCH 09/17] fix value --- lightrag/kg/json_doc_status_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 15fbfcde..7fccf3c3 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -39,7 +39,7 @@ class JsonDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" - counts = {status: 0 for status in DocStatus} + counts = {status.value: 0 for status in DocStatus} for doc in self._data.values(): counts[doc["status"]] += 1 return counts From 
0e7aff96bbd70c2a347c4c87ad857b3374e6f41f Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 15:08:50 +0100 Subject: [PATCH 10/17] back to not making breaks --- lightrag/kg/age_impl.py | 9 ++++++++- lightrag/kg/chroma_impl.py | 4 ++++ lightrag/kg/faiss_impl.py | 7 ++++++- lightrag/kg/milvus_impl.py | 12 +++++++----- lightrag/kg/mongo_impl.py | 10 +++++++--- lightrag/kg/nano_vector_db_impl.py | 5 ++++- lightrag/kg/oracle_impl.py | 5 +++++ lightrag/kg/postgres_impl.py | 5 +++++ lightrag/kg/qdrant_impl.py | 5 +++++ lightrag/kg/tidb_impl.py | 7 +++++++ 10 files changed, 58 insertions(+), 11 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index f9499376..b2eda5da 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np - +import pipmaster as pm from lightrag.types import KnowledgeGraph from tenacity import ( @@ -27,6 +27,13 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") + +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + try: import psycopg from psycopg.rows import namedtuple_row diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index ecac2b62..340e7a66 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -5,6 +5,10 @@ import numpy as np from lightrag.base import BaseVectorStorage from lightrag.utils import logger +import pipmaster as pm + +if not pm.is_installed("chromadb"): + pm.install("chromadb") try: from chromadb import HttpClient, PersistentClient diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0f455f1c..dea8aac0 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py 
@@ -5,8 +5,9 @@ from typing import Any, final import json import numpy as np -from tqdm.asyncio import tqdm as tqdm_async + from dataclasses import dataclass +import pipmaster as pm from lightrag.utils import ( logger, @@ -16,8 +17,12 @@ from lightrag.base import ( BaseVectorStorage, ) +if not pm.is_installed("faiss"): + pm.install("faiss") + try: import faiss + from tqdm.asyncio import tqdm as tqdm_async except ImportError as e: raise ImportError( "`faiss` library is not installed. Please install it via pip: `pip install faiss`." diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 1288df07..9aa9c37d 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,16 +6,18 @@ from dataclasses import dataclass import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage - +import pipmaster as pm import configparser +if not pm.is_installed("pymilvus"): + pm.install("pymilvus") + try: from pymilvus import MilvusClient -except ImportError: +except ImportError as e: raise ImportError( - "pymilvus library is not installed. Please install it to proceed." - ) - + "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`." 
+ ) from e config = configparser.ConfigParser() config.read("config.ini", "utf-8") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 330fa474..78b63179 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -18,6 +18,13 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +import pipmaster as pm + +if not pm.is_installed("pymongo"): + pm.install("pymongo") + +if not pm.is_installed("motor"): + pm.install("motor") try: from motor.motor_asyncio import AsyncIOMotorClient @@ -214,9 +221,6 @@ class MongoDocStatusStorage(DocStatusStorage): # Implement the method here pass - async def update_doc_status(self, data: dict[str, Any]) -> None: - raise NotImplementedError - @final @dataclass diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index fbd6a06a..30b766e0 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -11,11 +11,14 @@ from lightrag.utils import ( logger, compute_mdhash_id, ) - +import pipmaster as pm from lightrag.base import ( BaseVectorStorage, ) +if not pm.is_installed("nano-vectordb"): + pm.install("nano-vectordb") + try: from nano_vectordb import NanoVectorDB except ImportError as e: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d2d10141..560ffb88 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -18,6 +18,11 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger +import pipmaster as pm + +if not pm.is_installed("oracledb"): + pm.install("oracledb") + try: import oracledb diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 0dffa7d3..0ec30644 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -34,6 +34,11 @@ if sys.platform.startswith("win"): 
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +import pipmaster as pm + +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + try: import asyncpg from tqdm.asyncio import tqdm as tqdm_async diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 124d48d9..b4190f1f 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -14,6 +14,11 @@ import configparser config = configparser.ConfigParser() config.read("config.ini", "utf-8") +import pipmaster as pm + +if not pm.is_installed("qdrant_client"): + pm.install("qdrant_client") + try: from qdrant_client import QdrantClient, models diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 003316d3..4f7c891c 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -13,6 +13,13 @@ from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace from ..utils import logger +import pipmaster as pm + +if not pm.is_installed("pymysql"): + pm.install("pymysql") +if not pm.is_installed("sqlalchemy"): + pm.install("sqlalchemy") + try: from sqlalchemy import create_engine, text From 0c21442ca42225d0704a314898fbf26baa5fc113 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 15:20:46 +0100 Subject: [PATCH 11/17] fixed default init --- lightrag/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/base.py b/lightrag/base.py index ee13a11f..c44d6af8 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,7 +92,7 @@ class StorageNameSpace(ABC): @dataclass class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc - cosine_better_than_threshold: float + cosine_better_than_threshold: float = field(default=0.2) meta_fields: set[str] = field(default_factory=set) @abstractmethod From 2bf238396e5af40caa02cb7674b5e761e5e565f7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 15:52:59 +0100 Subject: [PATCH 
12/17] updated wrong status --- lightrag/base.py | 23 +++++-------------- lightrag/kg/json_doc_status_impl.py | 34 +++++++---------------------- lightrag/kg/mongo_impl.py | 16 -------------- lightrag/kg/postgres_impl.py | 18 +-------------- lightrag/kg/redis_impl.py | 4 ++-- 5 files changed, 17 insertions(+), 78 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index c44d6af8..98bdb606 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import StrEnum import os from dataclasses import dataclass, field -from enum import Enum from typing import ( Any, Literal, @@ -203,7 +203,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): """Retrieve a subgraph of the knowledge graph starting from a given node.""" -class DocStatus(str, Enum): +class DocStatus(StrEnum): """Document processing status enum""" PENDING = "pending" @@ -245,18 +245,7 @@ class DocStatusStorage(BaseKVStorage, ABC): """Get counts of documents in each status""" @abstractmethod - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - - @abstractmethod - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - raise NotImplementedError - - @abstractmethod - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - - @abstractmethod - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 7fccf3c3..33df6d43 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -44,33 +44,15 @@ class JsonDocStatusStorage(DocStatusStorage): 
counts[doc["status"]] += 1 return counts - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.FAILED - } - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PENDING - } - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSED - } - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSING - } + k: DocProcessingStatus(**v) + for k, v in self._data.items() + if v["status"] == status.value + } async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 219ec313..abc0aeb5 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -201,22 +201,6 @@ class MongoDocStatusStorage(DocStatusStorage): for doc in result } - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await 
self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self) -> None: # Implement the method here pass diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 9bd17ec5..33b4259f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -484,7 +484,7 @@ class PGDocStatusStorage(DocStatusStorage): ) -> Dict[str, DocProcessingStatus]: """all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" - params = {"workspace": self.db.workspace, "status": status} + params = {"workspace": self.db.workspace, "status": status.value} result = await self.db.query(sql, params, True) return { element["id"]: DocProcessingStatus( @@ -499,22 +499,6 @@ class PGDocStatusStorage(DocStatusStorage): for element in result } - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self) -> None: pass diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 71e39c5c..98258741 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -18,7 +18,7 @@ import json config = configparser.ConfigParser() config.read("config.ini", "utf-8") - 
+@final @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self): From abad9f235c8d18ed440bf9d45d860a0600540f35 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 15:54:54 +0100 Subject: [PATCH 13/17] make value on str enum --- lightrag/kg/json_doc_status_impl.py | 8 ++++---- lightrag/kg/redis_impl.py | 1 + lightrag/lightrag.py | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 33df6d43..6c667891 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -49,10 +49,10 @@ class JsonDocStatusStorage(DocStatusStorage): ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == status.value - } + k: DocProcessingStatus(**v) + for k, v in self._data.items() + if v["status"] == status.value + } async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 98258741..8dae1e77 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -18,6 +18,7 @@ import json config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @final @dataclass class RedisKVStorage(BaseKVStorage): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 82209504..3ad94da0 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -674,7 +674,7 @@ class LightRAG: "content": content, "content_summary": self._get_content_summary(content), "content_length": len(content), - "status": DocStatus.PENDING, + "status": DocStatus.PENDING.value, "created_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(), } @@ -745,7 +745,7 @@ class LightRAG: await self.doc_status.upsert( { doc_status_id: { - "status": DocStatus.PROCESSING, + "status": DocStatus.PROCESSING.value, 
"updated_at": datetime.now().isoformat(), "content": status_doc.content, "content_summary": status_doc.content_summary, @@ -782,7 +782,7 @@ class LightRAG: await self.doc_status.upsert( { doc_status_id: { - "status": DocStatus.PROCESSED, + "status": DocStatus.PROCESSED.value, "chunks_count": len(chunks), "content": status_doc.content, "content_summary": status_doc.content_summary, @@ -799,7 +799,7 @@ class LightRAG: await self.doc_status.upsert( { doc_status_id: { - "status": DocStatus.FAILED, + "status": DocStatus.FAILED.value, "error": str(e), "content": status_doc.content, "content_summary": status_doc.content_summary, From 2b2c81a7224f55a27ab056d1192b64b9420251a7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 16:04:07 +0100 Subject: [PATCH 14/17] added some comments --- lightrag/kg/age_impl.py | 1 + lightrag/kg/chroma_impl.py | 2 -- lightrag/kg/faiss_impl.py | 8 -------- lightrag/kg/gremlin_impl.py | 1 + lightrag/kg/milvus_impl.py | 3 +-- lightrag/kg/mongo_impl.py | 7 ++++--- lightrag/kg/nano_vector_db_impl.py | 1 - lightrag/kg/neo4j_impl.py | 5 +++++ lightrag/kg/oracle_impl.py | 23 ++++++++++------------- lightrag/kg/postgres_impl.py | 12 +++++++----- lightrag/kg/qdrant_impl.py | 3 +-- lightrag/kg/redis_impl.py | 1 + lightrag/kg/tidb_impl.py | 7 ++++--- 13 files changed, 35 insertions(+), 39 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index b2eda5da..243a110b 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -634,4 +634,5 @@ class AGEStorage(BaseGraphStorage): raise NotImplementedError async def index_done_callback(self) -> None: + # AGES handles persistence automatically pass diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 340e7a66..62a9b601 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -198,9 +198,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): pass async def delete_entity(self, entity_name: str) -> None: - """Delete a 
single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index ec4e7776..2b67e2fa 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -223,10 +223,6 @@ class FaissVectorDBStorage(BaseVectorStorage): ) async def delete_entity(self, entity_name: str) -> None: - """ - Delete a single entity by computing its hashed ID - the same way your code does it with `compute_mdhash_id`. - """ entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) @@ -247,11 +243,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Deleted {len(relations)} relations for {entity_name}") async def index_done_callback(self) -> None: - """ - Called after indexing is done (save Faiss index + metadata). 
- """ self._save_faiss_index() - logger.info("Faiss index saved successfully.") # -------------------------------------------------------------------------------- # Internal helper methods diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index d95bb00b..40a9f007 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -88,6 +88,7 @@ class GremlinStorage(BaseGraphStorage): self._driver.close() async def index_done_callback(self) -> None: + # Gremlin handles persistence automatically pass @staticmethod diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 9aa9c37d..3e8f1ba5 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -130,12 +130,11 @@ class MilvusVectorDBStorage(BaseVectorStorage): ] async def index_done_callback(self) -> None: + # Milvus handles persistence automatically pass async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index abc0aeb5..4eb968cf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -114,6 +114,7 @@ class MongoKVStorage(BaseKVStorage): return None async def index_done_callback(self) -> None: + # Mongo handles persistence automatically pass async def drop(self) -> None: @@ -202,7 +203,7 @@ class MongoDocStatusStorage(DocStatusStorage): } async def index_done_callback(self) -> None: - # Implement the method here + # Mongo handles persistence automatically pass @@ -771,6 +772,7 @@ class MongoGraphStorage(BaseGraphStorage): return result async def index_done_callback(self) -> None: + # Mongo handles persistence automatically pass @@ -919,14 +921,13 @@ class MongoVectorDBStorage(BaseVectorStorage): ] async def index_done_callback(self) -> 
None: + # Mongo handles persistence automatically pass async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index c631b086..16955d8a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -167,6 +167,5 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: - # Protect file write operation async with self._save_lock: self._client.save() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 83edd299..64721a49 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -18,7 +18,11 @@ from tenacity import ( from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +import pipmaster as pm +if not pm.is_installed("neo4j"): + pm.install("neo4j") + try: from neo4j import ( AsyncGraphDatabase, @@ -145,6 +149,7 @@ class Neo4JStorage(BaseGraphStorage): await self._driver.close() async def index_done_callback(self) -> None: + # Noe4J handles persistence automatically pass async def _label_exists(self, label: str) -> bool: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 560ffb88..1614543a 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -317,11 +317,8 @@ class OracleKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into oracle db!") + # Oracle 
handles persistence automatically + pass async def drop(self) -> None: raise NotImplementedError @@ -339,12 +336,6 @@ class OracleVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - pass - - async def index_done_callback(self) -> None: - pass - #################### query method ############### async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) @@ -366,12 +357,17 @@ class OracleVectorDBStorage(BaseVectorStorage): # print("vector search result:",results) return results + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + raise NotImplementedError + + async def index_done_callback(self) -> None: + # Oracles handles persistence automatically + pass + async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError @@ -476,6 +472,7 @@ class OracleGraphStorage(BaseGraphStorage): return embeddings, nodes_ids async def index_done_callback(self) -> None: + # Oracles handles persistence automatically pass #################### query method ################# diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 33b4259f..193af263 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -298,6 +298,7 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: + # PG handles persistence automatically pass async def drop(self) -> None: @@ -404,9 +405,6 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) - async def index_done_callback(self) -> None: - pass - #################### query method ############### async def query(self, query: 
str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) @@ -422,12 +420,14 @@ class PGVectorStorage(BaseVectorStorage): results = await self.db.query(sql, params=params, multirows=True) return results + async def index_done_callback(self) -> None: + # PG handles persistence automatically + pass + async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError @@ -500,6 +500,7 @@ class PGDocStatusStorage(DocStatusStorage): } async def index_done_callback(self) -> None: + # PG handles persistence automatically pass async def upsert(self, data: dict[str, dict[str, Any]]) -> None: @@ -569,6 +570,7 @@ class PGGraphStorage(BaseGraphStorage): } async def index_done_callback(self) -> None: + # PG handles persistence automatically pass @staticmethod diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index b4190f1f..0610346f 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -153,12 +153,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] async def index_done_callback(self) -> None: + # Qdrant handles persistence automatically pass async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 8dae1e77..2d5c94ce 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -64,4 +64,5 @@ class RedisKVStorage(BaseKVStorage): await self._redis.delete(*keys) async def index_done_callback(self) -> None: + # 
Redis handles persistence automatically pass diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 4f7c891c..6dbfb934 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -211,6 +211,7 @@ class TiDBKVStorage(BaseKVStorage): return left_data async def index_done_callback(self) -> None: + # Ti handles persistence automatically pass async def drop(self) -> None: @@ -339,15 +340,14 @@ class TiDBVectorDBStorage(BaseVectorStorage): return await self.db.query(SQL, params, multirows=True) async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" raise NotImplementedError async def index_done_callback(self) -> None: - raise NotImplementedError + # Ti handles persistence automatically + pass @final @@ -489,6 +489,7 @@ class TiDBGraphStorage(BaseGraphStorage): return [] async def index_done_callback(self) -> None: + # Ti handles persistence automatically pass async def delete_node(self, node_id: str) -> None: From 49bea486a7add5edba03b8a1ae631b673bb4e155 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 16:04:35 +0100 Subject: [PATCH 15/17] cleaned code --- lightrag/kg/neo4j_impl.py | 2 +- lightrag/kg/oracle_impl.py | 2 +- lightrag/kg/postgres_impl.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 64721a49..5ffbf2bc 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -22,7 +22,7 @@ import pipmaster as pm if not pm.is_installed("neo4j"): pm.install("neo4j") - + try: from neo4j import ( AsyncGraphDatabase, diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 1614543a..c9d8d1b5 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -363,7 +363,7 @@ class 
OracleVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: # Oracles handles persistence automatically pass - + async def delete_entity(self, entity_name: str) -> None: raise NotImplementedError diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 193af263..5f845894 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -423,7 +423,7 @@ class PGVectorStorage(BaseVectorStorage): async def index_done_callback(self) -> None: # PG handles persistence automatically pass - + async def delete_entity(self, entity_name: str) -> None: raise NotImplementedError From 0b16718f9f150f8f8cc218506cb08c3117833f6a Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 16:21:28 +0100 Subject: [PATCH 16/17] add missing final --- lightrag/kg/json_kv_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 779c52a9..658e1239 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Any +from typing import Any, final from lightrag.base import ( BaseKVStorage, @@ -13,6 +13,7 @@ from lightrag.utils import ( ) +@final @dataclass class JsonKVStorage(BaseKVStorage): def __post_init__(self): From 87a13fd3ea09720b9f16e645d2c4bc5e3228b0e7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 16:22:28 +0100 Subject: [PATCH 17/17] cleaned code --- lightrag/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 98bdb606..8b98d3aa 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -140,7 +140,6 @@ class BaseKVStorage(StorageNameSpace, ABC): @dataclass class BaseGraphStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc - """Check if a node exists in the graph.""" @abstractmethod async def has_node(self, node_id: str) -> bool: 
@@ -204,7 +203,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): class DocStatus(StrEnum): - """Document processing status enum""" + """Document processing status""" PENDING = "pending" PROCESSING = "processing"