Merge branch 'main' into improve-property-tooltip
This commit is contained in:
@@ -127,6 +127,30 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseKVStorage(StorageNameSpace, ABC):
|
||||
|
@@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
@@ -269,3 +271,67 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Query the collection for a single vector by ID
|
||||
result = self._collection.get(
|
||||
ids=[id], include=["metadatas", "embeddings", "documents"]
|
||||
)
|
||||
|
||||
if not result or not result["ids"] or len(result["ids"]) == 0:
|
||||
return None
|
||||
|
||||
# Format the result to match the expected structure
|
||||
return {
|
||||
"id": result["ids"][0],
|
||||
"vector": result["embeddings"][0],
|
||||
"content": result["documents"][0],
|
||||
**result["metadatas"][0],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Query the collection for multiple vectors by IDs
|
||||
result = self._collection.get(
|
||||
ids=ids, include=["metadatas", "embeddings", "documents"]
|
||||
)
|
||||
|
||||
if not result or not result["ids"] or len(result["ids"]) == 0:
|
||||
return []
|
||||
|
||||
# Format the results to match the expected structure
|
||||
return [
|
||||
{
|
||||
"id": result["ids"][i],
|
||||
"vector": result["embeddings"][i],
|
||||
"content": result["documents"][i],
|
||||
**result["metadatas"][i],
|
||||
}
|
||||
for i in range(len(result["ids"]))
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
@@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||
"""
|
||||
@@ -392,3 +394,46 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
|
||||
return matching_records
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
# Find the Faiss internal ID for the custom ID
|
||||
fid = self._find_faiss_id_by_custom_id(id)
|
||||
if fid is None:
|
||||
return None
|
||||
|
||||
# Get the metadata for the found ID
|
||||
metadata = self._id_to_meta.get(fid, {})
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
return {**metadata, "id": metadata.get("__id__")}
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for id in ids:
|
||||
fid = self._find_faiss_id_by_custom_id(id)
|
||||
if fid is not None:
|
||||
metadata = self._id_to_meta.get(fid, {})
|
||||
if metadata:
|
||||
results.append({**metadata, "id": metadata.get("__id__")})
|
||||
|
||||
return results
|
||||
|
@@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
||||
return results
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
@@ -231,3 +233,57 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Query Milvus for a specific ID
|
||||
result = self._client.query(
|
||||
collection_name=self.namespace,
|
||||
filter=f'id == "{id}"',
|
||||
output_fields=list(self.meta_fields) + ["id"],
|
||||
)
|
||||
|
||||
if not result or len(result) == 0:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Prepare the ID filter expression
|
||||
id_list = '", "'.join(ids)
|
||||
filter_expr = f'id in ["{id_list}"]'
|
||||
|
||||
# Query Milvus with the filter
|
||||
result = self._client.query(
|
||||
collection_name=self.namespace,
|
||||
filter=filter_expr,
|
||||
output_fields=list(self.meta_fields) + ["id"],
|
||||
)
|
||||
|
||||
return result or []
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
@@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
return list_data
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Queries the vector database using Atlas Vector Search."""
|
||||
# Generate the embedding
|
||||
embedding = await self.embedding_func([query])
|
||||
@@ -1071,6 +1073,59 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Search for the specific ID in MongoDB
|
||||
result = await self._data.find_one({"_id": id})
|
||||
if result:
|
||||
# Format the result to include id field expected by API
|
||||
result_dict = dict(result)
|
||||
if "_id" in result_dict and "id" not in result_dict:
|
||||
result_dict["id"] = result_dict["_id"]
|
||||
return result_dict
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Query MongoDB for multiple IDs
|
||||
cursor = self._data.find({"_id": {"$in": ids}})
|
||||
results = await cursor.to_list(length=None)
|
||||
|
||||
# Format results to include id field expected by API
|
||||
formatted_results = []
|
||||
for result in results:
|
||||
result_dict = dict(result)
|
||||
if "_id" in result_dict and "id" not in result_dict:
|
||||
result_dict["id"] = result_dict["_id"]
|
||||
formatted_results.append(result_dict)
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
@@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
@@ -256,3 +258,33 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
|
||||
return matching_records
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
client = await self._get_client()
|
||||
result = client.get([id])
|
||||
if result:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
client = await self._get_client()
|
||||
return client.get(ids)
|
||||
|
@@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
self.db = None
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
# 转换精度
|
||||
@@ -529,6 +531,80 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Determine the table name based on namespace
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
|
||||
return None
|
||||
|
||||
# Create the appropriate ID field name based on namespace
|
||||
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
||||
if "CHUNKS" in table_name:
|
||||
id_field = "chunk_id"
|
||||
|
||||
# Prepare and execute the query
|
||||
query = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {id_field} = :id AND workspace = :workspace
|
||||
"""
|
||||
params = {"id": id, "workspace": self.db.workspace}
|
||||
|
||||
result = await self.db.query(query, params)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Determine the table name based on namespace
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
|
||||
return []
|
||||
|
||||
# Create the appropriate ID field name based on namespace
|
||||
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
||||
if "CHUNKS" in table_name:
|
||||
id_field = "chunk_id"
|
||||
|
||||
# Format the list of IDs for SQL IN clause
|
||||
ids_list = ", ".join([f"'{id}'" for id in ids])
|
||||
|
||||
# Prepare and execute the query
|
||||
query = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
|
||||
"""
|
||||
params = {"workspace": self.db.workspace}
|
||||
|
||||
results = await self.db.query(query, params, multirows=True)
|
||||
return results or []
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
@@ -621,6 +621,60 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
logger.error(f"Error during prefix search for '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
|
||||
return None
|
||||
|
||||
query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2"
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
|
||||
try:
|
||||
result = await self.db.query(query, params)
|
||||
if result:
|
||||
return dict(result)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
|
||||
return []
|
||||
|
||||
ids_str = ",".join([f"'{id}'" for id in ids])
|
||||
query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
||||
params = {"workspace": self.db.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(query, params, multirows=True)
|
||||
return [dict(record) for record in results]
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
@@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
return results
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
|
@@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search from tidb vector"""
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
@@ -463,6 +465,100 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Determine which table to query based on namespace
|
||||
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
|
||||
sql_template = """
|
||||
SELECT entity_id as id, name as entity_name, entity_type, description, content
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE entity_id = :entity_id AND workspace = :workspace
|
||||
"""
|
||||
params = {"entity_id": id, "workspace": self.db.workspace}
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
|
||||
sql_template = """
|
||||
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
|
||||
keywords, description, content
|
||||
FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE relation_id = :relation_id AND workspace = :workspace
|
||||
"""
|
||||
params = {"relation_id": id, "workspace": self.db.workspace}
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
|
||||
sql_template = """
|
||||
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE chunk_id = :chunk_id AND workspace = :workspace
|
||||
"""
|
||||
params = {"chunk_id": id, "workspace": self.db.workspace}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Namespace {self.namespace} not supported for get_by_id"
|
||||
)
|
||||
return None
|
||||
|
||||
result = await self.db.query(sql_template, params=params)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Format IDs for SQL IN clause
|
||||
ids_str = ", ".join([f"'{id}'" for id in ids])
|
||||
|
||||
# Determine which table to query based on namespace
|
||||
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
|
||||
sql_template = f"""
|
||||
SELECT entity_id as id, name as entity_name, entity_type, description, content
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE entity_id IN ({ids_str}) AND workspace = :workspace
|
||||
"""
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
|
||||
sql_template = f"""
|
||||
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
|
||||
keywords, description, content
|
||||
FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE relation_id IN ({ids_str}) AND workspace = :workspace
|
||||
"""
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
|
||||
sql_template = f"""
|
||||
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE chunk_id IN ({ids_str}) AND workspace = :workspace
|
||||
"""
|
||||
else:
|
||||
logger.warning(
|
||||
f"Namespace {self.namespace} not supported for get_by_ids"
|
||||
)
|
||||
return []
|
||||
|
||||
params = {"workspace": self.db.workspace}
|
||||
results = await self.db.query(sql_template, params=params, multirows=True)
|
||||
return results if results else []
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
@@ -1710,19 +1710,7 @@ class LightRAG:
|
||||
async def get_entity_info(
|
||||
self, entity_name: str, include_vector_data: bool = False
|
||||
) -> dict[str, str | None | dict[str, str]]:
|
||||
"""Get detailed information of an entity
|
||||
|
||||
Args:
|
||||
entity_name: Entity name (no need for quotes)
|
||||
include_vector_data: Whether to include data from the vector database
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing entity information, including:
|
||||
- entity_name: Entity name
|
||||
- source_id: Source document ID
|
||||
- graph_data: Complete node data from the graph database
|
||||
- vector_data: (optional) Data from the vector database
|
||||
"""
|
||||
"""Get detailed information of an entity"""
|
||||
|
||||
# Get information from the graph
|
||||
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
||||
@@ -1737,29 +1725,15 @@ class LightRAG:
|
||||
# Optional: Get vector database information
|
||||
if include_vector_data:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
vector_data = self.entities_vdb._client.get([entity_id])
|
||||
result["vector_data"] = vector_data[0] if vector_data else None
|
||||
vector_data = await self.entities_vdb.get_by_id(entity_id)
|
||||
result["vector_data"] = vector_data
|
||||
|
||||
return result
|
||||
|
||||
async def get_relation_info(
|
||||
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
|
||||
) -> dict[str, str | None | dict[str, str]]:
|
||||
"""Get detailed information of a relationship
|
||||
|
||||
Args:
|
||||
src_entity: Source entity name (no need for quotes)
|
||||
tgt_entity: Target entity name (no need for quotes)
|
||||
include_vector_data: Whether to include data from the vector database
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing relationship information, including:
|
||||
- src_entity: Source entity name
|
||||
- tgt_entity: Target entity name
|
||||
- source_id: Source document ID
|
||||
- graph_data: Complete edge data from the graph database
|
||||
- vector_data: (optional) Data from the vector database
|
||||
"""
|
||||
"""Get detailed information of a relationship"""
|
||||
|
||||
# Get information from the graph
|
||||
edge_data = await self.chunk_entity_relation_graph.get_edge(
|
||||
@@ -1777,8 +1751,8 @@ class LightRAG:
|
||||
# Optional: Get vector database information
|
||||
if include_vector_data:
|
||||
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
|
||||
vector_data = self.relationships_vdb._client.get([rel_id])
|
||||
result["vector_data"] = vector_data[0] if vector_data else None
|
||||
vector_data = await self.relationships_vdb.get_by_id(rel_id)
|
||||
result["vector_data"] = vector_data
|
||||
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user