diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index cbbd98c7..c91d23f0 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -254,6 +254,8 @@ class PGKVStorage(BaseKVStorage): db: PostgreSQLDB = field(default=None) def __post_init__(self): + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): @@ -269,7 +271,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data by id.""" - sql = SQL_TEMPLATES["get_by_id_" + self.namespace] + sql = SQL_TEMPLATES["get_by_id_" + self.base_namespace] params = {"workspace": self.db.workspace, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -283,7 +285,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + sql = SQL_TEMPLATES["get_by_mode_id_" + self.base_namespace] params = {"workspace": self.db.workspace, mode: mode, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -297,7 +299,7 @@ class PGKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + sql = SQL_TEMPLATES["get_by_ids_" + self.base_namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} @@ -318,7 +320,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + SQL = SQL_TEMPLATES["get_by_status_" + self.base_namespace] params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) @@ -391,6 +393,8 @@ class PGVectorStorage(BaseVectorStorage): def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: @@ -493,7 +497,9 @@ class PGVectorStorage(BaseVectorStorage): embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + sql = SQL_TEMPLATES[self.base_namespace].format( + embedding_string=embedding_string + ) params = { "workspace": self.db.workspace, "better_than_threshold": self.cosine_better_than_threshold,