diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py index 913361b3..8f40690e 100644 --- a/examples/lightrag_zhipu_postgres_demo.py +++ b/examples/lightrag_zhipu_postgres_demo.py @@ -37,20 +37,22 @@ async def main(): llm_model_max_token_size=32768, enable_llm_cache_for_entity_extract=True, embedding_func=EmbeddingFunc( - embedding_dim=768, + embedding_dim=1024, max_token_size=8192, func=lambda texts: ollama_embedding( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" + texts, embed_model="bge-m3", host="http://localhost:11434" ), ), kv_storage="PGKVStorage", doc_status_storage="PGDocStatusStorage", graph_storage="PGGraphStorage", vector_storage="PGVectorStorage", + auto_manage_storages_states=False, ) # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func + await rag.initialize_storages() with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f: await rag.ainsert(f.read()) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 2a78af9b..04f71139 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.11" +__version__ = "1.2.1" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index cbbd98c7..c91d23f0 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -254,6 +254,8 @@ class PGKVStorage(BaseKVStorage): db: PostgreSQLDB = field(default=None) def __post_init__(self): + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): @@ -269,7 +271,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data by id.""" - sql = SQL_TEMPLATES["get_by_id_" + self.namespace] + sql = SQL_TEMPLATES["get_by_id_" + self.base_namespace] params = {"workspace": self.db.workspace, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -283,7 +285,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + sql = SQL_TEMPLATES["get_by_mode_id_" + self.base_namespace] params = {"workspace": self.db.workspace, mode: mode, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -297,7 +299,7 @@ class PGKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + sql = SQL_TEMPLATES["get_by_ids_" + self.base_namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} @@ -318,7 +320,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + SQL = SQL_TEMPLATES["get_by_status_" + self.base_namespace] params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) @@ -391,6 +393,8 @@ class PGVectorStorage(BaseVectorStorage): def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: @@ -493,7 +497,9 @@ class PGVectorStorage(BaseVectorStorage): embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + sql = SQL_TEMPLATES[self.base_namespace].format( + embedding_string=embedding_string + ) params = { "workspace": self.db.workspace, "better_than_threshold": self.cosine_better_than_threshold,