From c26cb3a9ea6747e745805ccc2e48ec5e6f52b0fa Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Mar 2025 16:05:04 +0800 Subject: [PATCH] fix merge bugs --- lightrag/base.py | 24 ++++++++ lightrag/kg/chroma_impl.py | 64 ++++++++++++++++++++ lightrag/kg/faiss_impl.py | 43 ++++++++++++++ lightrag/kg/milvus_impl.py | 54 +++++++++++++++++ lightrag/kg/mongo_impl.py | 53 +++++++++++++++++ lightrag/kg/nano_vector_db_impl.py | 30 ++++++++++ lightrag/kg/oracle_impl.py | 74 +++++++++++++++++++++++ lightrag/kg/postgres_impl.py | 54 +++++++++++++++++ lightrag/kg/tidb_impl.py | 94 ++++++++++++++++++++++++++++++ lightrag/lightrag.py | 38 ++---------- 10 files changed, 496 insertions(+), 32 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index c84c7c62..86566787 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -127,6 +127,30 @@ class BaseVectorStorage(StorageNameSpace, ABC): async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity.""" + @abstractmethod + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + pass + + @abstractmethod + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + pass + @dataclass class BaseKVStorage(StorageNameSpace, ABC): diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 6b521180..f668c87a 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -269,3 +269,67 @@ class ChromaVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error during prefix search in ChromaDB: {str(e)}") raise + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Query the collection for a single vector by ID + result = self._collection.get( + ids=[id], include=["metadatas", "embeddings", "documents"] + ) + + if not result or not result["ids"] or len(result["ids"]) == 0: + return None + + # Format the result to match the expected structure + return { + "id": result["ids"][0], + "vector": result["embeddings"][0], + "content": result["documents"][0], + **result["metadatas"][0], + } + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Query the collection for multiple vectors by IDs + result = self._collection.get( + ids=ids, include=["metadatas", "embeddings", "documents"] + ) + + if not result or not result["ids"] or len(result["ids"]) == 0: + return [] + + # Format the results to match the expected structure + return [ + { + "id": result["ids"][i], + "vector": result["embeddings"][i], + "content": result["documents"][i], + **result["metadatas"][i], + } + for i in range(len(result["ids"])) + ] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index ab036e6f..a5716e9c 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -392,3 +392,46 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'") return matching_records + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + # Find the Faiss internal ID for the custom ID + fid = self._find_faiss_id_by_custom_id(id) + if fid is None: + return None + + # Get the metadata for the found ID + metadata = self._id_to_meta.get(fid, {}) + if not metadata: + return None + + return {**metadata, "id": metadata.get("__id__")} + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + results = [] + for id in ids: + fid = self._find_faiss_id_by_custom_id(id) + if fid is not None: + metadata = self._id_to_meta.get(fid, {}) + if metadata: + results.append({**metadata, "id": metadata.get("__id__")}) + + return results diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f3a6fcc4..4fb5f012 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -231,3 +231,57 @@ class MilvusVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error searching for records with prefix '{prefix}': {e}") return [] + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Query Milvus for a specific ID + result = self._client.query( + collection_name=self.namespace, + filter=f'id == "{id}"', + output_fields=list(self.meta_fields) + ["id"], + ) + + if not result or len(result) == 0: + return None + + return result[0] + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Prepare the ID filter expression + id_list = '", "'.join(ids) + filter_expr = f'id in ["{id_list}"]' + + # Query Milvus with the filter + result = self._client.query( + collection_name=self.namespace, + filter=filter_expr, + output_fields=list(self.meta_fields) + ["id"], + ) + + return result or [] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index f2ab6ae0..da4dc32c 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1071,6 +1071,59 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}") return [] + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Search for the specific ID in MongoDB + result = await self._data.find_one({"_id": id}) + if result: + # Format the result to include id field expected by API + result_dict = dict(result) + if "_id" in result_dict and "id" not in result_dict: + result_dict["id"] = result_dict["_id"] + return result_dict + return None + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Query MongoDB for multiple IDs + cursor = self._data.find({"_id": {"$in": ids}}) + results = await cursor.to_list(length=None) + + # Format results to include id field expected by API + formatted_results = [] + for result in results: + result_dict = dict(result) + if "_id" in result_dict and "id" not in result_dict: + result_dict["id"] = result_dict["_id"] + formatted_results.append(result_dict) + + return formatted_results + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): collection_names = await db.list_collection_names() diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 07ccd566..ac010f16 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -256,3 +256,33 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'") return matching_records + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + client = await self._get_client() + result = client.get([id]) + if result: + return result[0] + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + client = await self._get_client() + return client.get(ids) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index eda3ca63..32790f4f 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -529,6 +529,80 @@ class OracleVectorDBStorage(BaseVectorStorage): logger.error(f"Error searching records with prefix '{prefix}': {e}") return [] + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Determine the table name based on namespace + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for ID lookup: {self.namespace}") + return None + + # Create the appropriate ID field name based on namespace + id_field = "entity_id" if "NODES" in table_name else "relation_id" + if "CHUNKS" in table_name: + id_field = "chunk_id" + + # Prepare and execute the query + query = f""" + SELECT * FROM {table_name} + WHERE {id_field} = :id AND workspace = :workspace + """ + params = {"id": id, "workspace": self.db.workspace} + + result = await self.db.query(query, params) + return result + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Determine the table name based on namespace + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") + return [] + + # Create the appropriate ID field name based on namespace + id_field = "entity_id" if "NODES" in table_name else "relation_id" + if "CHUNKS" in table_name: + id_field = "chunk_id" + + # Format the list of IDs for SQL IN clause + ids_list = ", ".join([f"'{id}'" for id in ids]) + + # Prepare and execute the query + query = f""" + SELECT * FROM {table_name} + WHERE {id_field} IN ({ids_list}) AND workspace = :workspace + """ + params = {"workspace": self.db.workspace} + + results = await self.db.query(query, params, multirows=True) + return results or [] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + @final @dataclass diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 1d525bdb..49d462f6 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -621,6 +621,60 @@ class PGVectorStorage(BaseVectorStorage): logger.error(f"Error during prefix search for '{prefix}': {e}") return [] + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for ID lookup: {self.namespace}") + return None + + query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2" + params = {"workspace": self.db.workspace, "id": id} + + try: + result = await self.db.query(query, params) + if result: + return dict(result) + return None + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") + return [] + + ids_str = ",".join([f"'{id}'" for id in ids]) + query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + params = {"workspace": self.db.workspace} + + try: + results = await self.db.query(query, params, multirows=True) + return [dict(record) for record in results] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + @final @dataclass diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 7af9b48a..c4485df6 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -463,6 +463,100 @@ class TiDBVectorDBStorage(BaseVectorStorage): logger.error(f"Error searching records with prefix '{prefix}': {e}") return [] + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Determine which table to query based on namespace + if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: + sql_template = """ + SELECT entity_id as id, name as entity_name, entity_type, description, content + FROM LIGHTRAG_GRAPH_NODES + WHERE entity_id = :entity_id AND workspace = :workspace + """ + params = {"entity_id": id, "workspace": self.db.workspace} + elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: + sql_template = """ + SELECT relation_id as id, source_name as src_id, target_name as tgt_id, + keywords, description, content + FROM LIGHTRAG_GRAPH_EDGES + WHERE relation_id = :relation_id AND workspace = :workspace + """ + params = {"relation_id": id, "workspace": self.db.workspace} + elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: + sql_template = """ + SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE chunk_id = :chunk_id AND workspace = :workspace + """ + params = {"chunk_id": id, "workspace": self.db.workspace} + else: + logger.warning( + f"Namespace {self.namespace} not supported for get_by_id" + ) + return None + + result = await self.db.query(sql_template, params=params) + return result + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Format IDs for SQL IN clause + ids_str = ", ".join([f"'{id}'" for id in ids]) + + # Determine which table to query based on namespace + if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: + sql_template = f""" + SELECT entity_id as id, name as entity_name, entity_type, description, content + FROM LIGHTRAG_GRAPH_NODES + WHERE entity_id IN ({ids_str}) AND workspace = :workspace + """ + elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: + sql_template = f""" + SELECT relation_id as id, source_name as src_id, target_name as tgt_id, + keywords, description, content + FROM LIGHTRAG_GRAPH_EDGES + WHERE relation_id IN ({ids_str}) AND workspace = :workspace + """ + elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: + sql_template = f""" + SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE chunk_id IN ({ids_str}) AND workspace = :workspace + """ + else: + logger.warning( + f"Namespace {self.namespace} not supported for get_by_ids" + ) + return [] + + params = {"workspace": self.db.workspace} + results = await self.db.query(sql_template, params=params, multirows=True) + return results if results else [] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + @final @dataclass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 3a7d340a..8ab8ece6 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1756,19 +1756,7 @@ class LightRAG: async def get_entity_info( self, entity_name: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: - """Get detailed information of an entity - - Args: - entity_name: Entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - - Returns: - dict: A dictionary containing entity information, including: - - entity_name: Entity name - - source_id: Source document ID - - graph_data: Complete node data from the graph database - - vector_data: (optional) Data from the vector database - """ + """Get detailed information of an entity""" # Get information from the graph node_data = await self.chunk_entity_relation_graph.get_node(entity_name) @@ -1783,29 +1771,15 @@ class LightRAG: # Optional: Get vector database information if include_vector_data: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - vector_data = self.entities_vdb._client.get([entity_id]) - result["vector_data"] = vector_data[0] if vector_data else None + vector_data = await self.entities_vdb.get_by_id(entity_id) + result["vector_data"] = vector_data return result async def get_relation_info( self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: - """Get detailed information of a relationship - - Args: - src_entity: Source entity name (no need for quotes) - tgt_entity: Target entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - - Returns: - dict: A dictionary containing relationship information, including: - - src_entity: Source entity name - - tgt_entity: Target entity name - - source_id: Source document ID - - graph_data: Complete edge data from the graph database - - vector_data: (optional) Data from the vector database - """ + """Get detailed information of a relationship""" # Get information from the graph edge_data = await self.chunk_entity_relation_graph.get_edge( @@ -1823,8 +1797,8 @@ class LightRAG: # Optional: Get vector database information if include_vector_data: rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-") - vector_data = self.relationships_vdb._client.get([rel_id]) - result["vector_data"] = vector_data[0] if vector_data else None + vector_data = await self.relationships_vdb.get_by_id(rel_id) + result["vector_data"] = vector_data return result