From 33caba3e127ce4827c853ea86e4dfdfa98f92c25 Mon Sep 17 00:00:00 2001
From: jin <52519003+jin38324@users.noreply.github.com>
Date: Tue, 12 Nov 2024 13:32:40 +0800
Subject: [PATCH] fix pre commit

---
 examples/lightrag_api_oracle_demo..py |  79 ++---
 examples/lightrag_oracle_demo.py      |  49 +--
 lightrag/base.py                      |   2 +
 lightrag/kg/oracle_impl.py            | 447 +++++++++++++-------------
 lightrag/lightrag.py                  |  52 +--
 lightrag/operate.py                   |   2 +-
 requirements.txt                      |  24 +-
 7 files changed, 345 insertions(+), 310 deletions(-)

diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py
index d63b4588..3bfae452 100644
--- a/examples/lightrag_api_oracle_demo..py
+++ b/examples/lightrag_api_oracle_demo..py
@@ -1,10 +1,10 @@
-
 from fastapi import FastAPI, HTTPException, File, UploadFile
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
 from typing import Optional
-import sys, os
+import sys
+import os
 
 from pathlib import Path
 import asyncio
@@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
@@ -24,8 +23,6 @@
 script_directory = Path(__file__).resolve().parent.parent
 sys.path.append(os.path.abspath(script_directory))
 
-
-
 # Apply nest_asyncio to solve event loop issues
 nest_asyncio.apply()
@@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -80,8 +78,8 @@ async def get_embedding_dim():
     embedding_dim = embedding.shape[1]
     return embedding_dim
 
+
 async def init():
-
     # Detect embedding dimension
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
@@ -91,36 +89,36 @@ async def init():
     # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
+    oracle_db = OracleDB(
+        config={
+            "user": "",
+            "password": "",
+            "dsn": "",
+            "config_dir": "",
+            "wallet_location": "",
+            "wallet_password": "",
+            "workspace": "",
+        }  # specify which docs you want to store and query
+    )
 
-    oracle_db = OracleDB(config={
-        "user":"",
-        "password":"",
-        "dsn":"",
-        "config_dir":"",
-        "wallet_location":"",
-        "wallet_password":"",
-        "workspace":""
-        } # specify which docs you want to store and query
-    )
-
     # Check if Oracle DB tables exist; if not, they will be created
     await oracle_db.check_tables()
 
     # Initialize LightRAG
-    # We use Oracle DB as the KV/vector/graph storage
+    # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
-        enable_llm_cache=False,
-        working_dir=WORKING_DIR,
-        chunk_token_size=512,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=embedding_dimension,
-            max_token_size=512,
-            func=embedding_func,
-        ),
-        graph_storage = "OracleGraphStorage",
-        kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage"
-    )
+        enable_llm_cache=False,
+        working_dir=WORKING_DIR,
+        chunk_token_size=512,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=embedding_dimension,
+            max_token_size=512,
+            func=embedding_func,
+        ),
+        graph_storage="OracleGraphStorage",
+        kv_storage="OracleKVStorage",
+        vector_storage="OracleVectorDBStorage",
+    )
 
     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
     rag.graph_storage_cls.db = oracle_db
@@ -129,6 +127,7 @@ async def init():
 
     return rag
 
+
 # Data models
@@ -152,6 +151,7 @@ class Response(BaseModel):
 
 rag = None  # defined as a global object
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
@@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
     yield
 
-app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
+app = FastAPI(
+    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
+)
+
 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
     try:
         # loop = asyncio.get_event_loop()
         result = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode, only_need_context=request.only_need_context
-                ),
-            )
+            request.query,
+            param=QueryParam(
+                mode=request.mode, only_need_context=request.only_need_context
+            ),
+        )
         return Response(status="success", data=result)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
@@ -234,4 +237,4 @@ if __name__ == "__main__":
 # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
 
 # 4. Health check:
-# curl -X GET "http://127.0.0.1:8020/health"
\ No newline at end of file
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index 3e196400..365b6225 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -1,11 +1,11 @@
-import sys, os
+import sys
+import os
 from pathlib import Path
 import asyncio
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
 
 print(os.getcwd())
@@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -66,22 +67,21 @@ async def main():
     # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
     # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(config={
-        "user":"username",
-        "password":"xxxxxxxxx",
-        "dsn":"xxxxxxx_medium",
-        "config_dir":"dir/path/to/oracle/config",
-        "wallet_location":"dir/path/to/oracle/wallet",
-        "wallet_password":"xxxxxxxxx",
-        "workspace":"company" # specify which docs you want to store and query
+    oracle_db = OracleDB(
+        config={
+            "user": "username",
+            "password": "xxxxxxxxx",
+            "dsn": "xxxxxxx_medium",
+            "config_dir": "dir/path/to/oracle/config",
+            "wallet_location": "dir/path/to/oracle/wallet",
+            "wallet_password": "xxxxxxxxx",
+            "workspace": "company",  # specify which docs you want to store and query
         }
-    )
-
+    )
 
     # Check if Oracle DB tables exist; if not, they will be created
     await oracle_db.check_tables()
 
-
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
@@ -93,10 +93,10 @@ async def main():
             embedding_dim=embedding_dimension,
             max_token_size=512,
             func=embedding_func,
-        ),
-        graph_storage = "OracleGraphStorage",
-        kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage"
+        ),
+        graph_storage="OracleGraphStorage",
+        kv_storage="OracleKVStorage",
+        vector_storage="OracleVectorDBStorage",
     )
 
     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
@@ -106,18 +106,23 @@ async def main():
 
         # Extract and Insert into LightRAG storage
         with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
-            await rag.ainsert(f.read())
+            await rag.ainsert(f.read())
 
         # Perform search in different modes
         modes = ["naive", "local", "global", "hybrid"]
         for mode in modes:
-            print("="*20, mode, "="*20)
-            print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode)))
-            print("-"*100, "\n")
+            print("=" * 20, mode, "=" * 20)
+            print(
+                await rag.aquery(
+                    "What are the top themes in this story?",
+                    param=QueryParam(mode=mode),
+                )
+            )
+            print("-" * 100, "\n")
 
     except Exception as e:
         print(f"An error occurred: {e}")
 
 
 if __name__ == "__main__":
-    asyncio.run(main())
\ No newline at end of file
+    asyncio.run(main())
diff --git a/lightrag/base.py b/lightrag/base.py
index 379efeb3..46dfc800 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace):
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
     embedding_func: EmbeddingFunc
+
     async def all_keys(self) -> list[str]:
         raise NotImplementedError
 
@@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc = None
+
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
 
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index e0671a71..96a9e795 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -1,9 +1,9 @@
 import asyncio
-#import html
-#import os
+
+# import html
+# import os
 from dataclasses import dataclass
-from typing import Any, Union, cast
-import networkx as nx
+from typing import Union
 import numpy as np
 import array
@@ -16,8 +16,9 @@ from ..base import (
 
 import oracledb
 
+
 class OracleDB:
-    def __init__(self,config,**kwargs):
+    def __init__(self, config, **kwargs):
         self.host = config.get("host", None)
         self.port = config.get("port", None)
         self.user = config.get("user", None)
@@ -32,21 +33,21 @@ class OracleDB:
         logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
         if self.user is None or self.password is None:
             raise ValueError("Missing database user or password in addon_params")
-
         try:
             oracledb.defaults.fetch_lobs = False
             self.pool = oracledb.create_pool_async(
-                user = self.user,
-                password = self.password,
-                dsn = self.dsn,
-                config_dir = self.config_dir,
-                wallet_location = self.wallet_location,
-                wallet_password = self.wallet_password,
-                min = 1,
-                max = self.max,
-                increment = self.increment
-            )
+                user=self.user,
+                password=self.password,
+                dsn=self.dsn,
+                config_dir=self.config_dir,
+                wallet_location=self.wallet_location,
+                wallet_password=self.wallet_password,
+                min=1,
+                max=self.max,
+                increment=self.increment,
+            )
             logger.info(f"Connected to Oracle database at {self.dsn}")
         except Exception as e:
             logger.error(f"Failed to connect to Oracle database at {self.dsn}")
@@ -90,12 +91,14 @@ class OracleDB:
                     arraysize=cursor.arraysize,
                     outconverter=self.numpy_converter_out,
                 )
-
+
     async def check_tables(self):
-        for k,v in TABLES.items():
+        for k, v in TABLES.items():
             try:
                 if k.lower() == "lightrag_graph":
-                    await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
+                    await self.query(
+                        "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
+                    )
                 else:
                     await self.query("SELECT 1 FROM {k}".format(k=k))
             except Exception as e:
@@ -108,12 +111,11 @@ class OracleDB:
                 except Exception as e:
                     logger.error(f"Failed to create table {k} in Oracle database")
                     logger.error(f"Oracle database error: {e}")
-
-        logger.info(f"Finished check all tables in Oracle database")
-
-
-    async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
-        async with self.pool.acquire() as connection:
+
+        logger.info("Finished check all tables in Oracle database")
+
+    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+        async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
@@ -136,9 +138,9 @@ class OracleDB:
                         data = dict(zip(columns, row))
                 else:
                     data = None
-                return data
+                return data
 
-    async def execute(self,sql: str, data: list = None):
+    async def execute(self, sql: str, data: list = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
@@ -148,58 +150,63 @@ class OracleDB:
                     if data is None:
                         await cursor.execute(sql)
                     else:
-                        #print(data)
-                        #print(sql)
-                        await cursor.execute(sql,data)
+                        # print(data)
+                        # print(sql)
+                        await cursor.execute(sql, data)
                     await connection.commit()
         except Exception as e:
-            logger.error(f"Oracle database error: {e}")
+            logger.error(f"Oracle database error: {e}")
             print(sql)
             print(data)
             raise
 
+
 @dataclass
 class OracleKVStorage(BaseKVStorage):
-    # should pass db object to self.db
     def __post_init__(self):
         self._data = {}
-        self._max_batch_size = self.global_config["embedding_batch_num"]
-
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
-        #print("get_by_id:"+SQL)
-        res = await self.db.query(SQL)
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+            workspace=self.db.workspace, id=id
+        )
+        # print("get_by_id:"+SQL)
+        res = await self.db.query(SQL)
         if res:
-            data = res #{"data":res}
-            #print (data)
+            data = res  # {"data":res}
+            # print (data)
             return data
         else:
             return None
 
     # Query by id
-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
+    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by id"""
-        SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
-                                                                 ids=",".join([f"'{id}'" for id in ids]))
-        #print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL,multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+        )
+        # print("get_by_ids:"+SQL)
+        res = await self.db.query(SQL, multirows=True)
         if res:
-            data = res # [{"data":i} for i in res]
-            #print(data)
+            data = res  # [{"data":i} for i in res]
+            # print(data)
             return data
         else:
             return None
-
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicate content"""
-        SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
-                                                  workspace=self.db.workspace,
-                                                  ids=",".join([f"'{k}'" for k in keys]))
-        res = await self.db.query(SQL,multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            workspace=self.db.workspace,
+            ids=",".join([f"'{k}'" for k in keys]),
+        )
+        res = await self.db.query(SQL, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -208,14 +215,13 @@ class OracleKVStorage(BaseKVStorage):
             exist_keys = []
         data = set([s for s in keys if s not in exist_keys])
         return data
-
-
+
     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
-        #print(self._data)
-        #values = []
+        # print(self._data)
+        # values = []
         if self.namespace == "text_chunks":
             list_data = [
                 {
@@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
             ]
             contents = [v["content"] for v in data.values()]
             batches = [
-                contents[i: i + self._max_batch_size]
+                contents[i : i + self._max_batch_size]
                 for i in range(0, len(contents), self._max_batch_size)
             ]
             embeddings_list = await asyncio.gather(
@@ -235,42 +241,45 @@ class OracleKVStorage(BaseKVStorage):
             embeddings = np.concatenate(embeddings_list)
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            #print(list_data)
+            # print(list_data)
             for item in list_data:
-                merge_sql = SQL_TEMPLATES["merge_chunk"].format(
-                    check_id=item["__id__"]
-                )
+                merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
 
-                values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
-                          item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
-                #print(merge_sql)
+                values = [
+                    item["__id__"],
+                    item["content"],
+                    self.db.workspace,
+                    item["tokens"],
+                    item["chunk_order_index"],
+                    item["full_doc_id"],
+                    item["__vector__"],
+                ]
+                # print(merge_sql)
                 await self.db.execute(merge_sql, values)
         if self.namespace == "full_docs":
             for k, v in self._data.items():
-                #values.clear()
+                # values.clear()
                 merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
                     check_id=k,
                 )
                 values = [k, self._data[k]["content"], self.db.workspace]
-                #print(merge_sql)
+                # print(merge_sql)
                 await self.db.execute(merge_sql, values)
         return left_data
 
-
     async def index_done_callback(self):
         if self.namespace in ["full_docs", "text_chunks"]:
             logger.info("full doc and chunk data had been saved into oracle db!")
 
+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
 
     def __post_init__(self):
         pass
-
+
     async def upsert(self, data: dict[str, dict]):
         """Insert data into the vector database"""
         pass
@@ -278,53 +287,51 @@ class OracleVectorDBStorage(BaseVectorStorage):
     async def index_done_callback(self):
         pass
 
-
     #################### query method ###############
     async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
-        """Query data from the vector database"""
+        """Query data from the vector database"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
         # Convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = ', '.join(map(str, embedding.tolist()))
+        embedding_string = ", ".join(map(str, embedding.tolist()))
 
         SQL = SQL_TEMPLATES[self.namespace].format(
-            embedding_string=embedding_string,
-            dimension=dimension,
-            dtype=dtype,
-            workspace=self.db.workspace,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
+            embedding_string=embedding_string,
+            dimension=dimension,
+            dtype=dtype,
+            workspace=self.db.workspace,
+            top_k=top_k,
+            better_than_threshold=self.cosine_better_than_threshold,
+        )
         # print(SQL)
         results = await self.db.query(SQL, multirows=True)
-        #print("vector search result:",results)
+        # print("vector search result:",results)
         return results
 
 
 @dataclass
-class OracleGraphStorage(BaseGraphStorage):
+class OracleGraphStorage(BaseGraphStorage):
     """Oracle-based graph storage module"""
-
+
     def __post_init__(self):
         """Load the graph from the graphml file"""
         self._max_batch_size = self.global_config["embedding_batch_num"]
-
+
     #################### insert method ################
-
     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
         """Insert or update a node"""
-        #print("go into upsert node method")
+        # print("go into upsert node method")
         entity_name = node_id
         entity_type = node_data["entity_type"]
         description = node_data["description"]
-        source_id = node_data["source_id"]
-        content = entity_name+description
+        source_id = node_data["source_id"]
+        content = entity_name + description
         contents = [content]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
@@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage):
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
         merge_sql = SQL_TEMPLATES["merge_node"].format(
-            workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
+            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
         )
-        #print(merge_sql)
-        await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
-        #self._graph.add_node(node_id, **node_data)
+        # print(merge_sql)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                entity_name,
+                entity_type,
+                description,
+                source_id,
+                content,
+                content_vector,
+            ],
+        )
+        # self._graph.add_node(node_id, **node_data)
 
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         """Insert or update an edge"""
-        #print("go into upsert edge method")
+        # print("go into upsert edge method")
         source_name = source_node_id
         target_name = target_node_id
         weight = edge_data["weight"]
         keywords = edge_data["keywords"]
         description = edge_data["description"]
         source_chunk_id = edge_data["source_id"]
-        content = keywords+source_name+target_name+description
+        content = keywords + source_name + target_name + description
         contents = [content]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
@@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage):
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
         merge_sql = SQL_TEMPLATES["merge_edge"].format(
-            workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
+            workspace=self.db.workspace,
+            source_name=source_name,
+            target_name=target_name,
+            source_chunk_id=source_chunk_id,
         )
-        #print(merge_sql)
-        await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
-        #self._graph.add_edge(source_node_id, target_node_id, **edge_data)
+        # print(merge_sql)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                source_name,
+                target_name,
+                weight,
+                keywords,
+                description,
+                source_chunk_id,
+                content,
+                content_vector,
+            ],
+        )
+        # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         """Generate vectors for nodes"""
@@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
 
-
     async def index_done_callback(self):
         """Write the graphml graph file"""
-        logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
-
+        logger.info(
+            "Node and edge data had been saved into oracle db already, so nothing to do here!"
+        )
+
     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
-        """Check whether a node exists by node id"""
-        SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
-        # print(SQL)
-        #print(self.db.workspace, node_id)
+        """Check whether a node exists by node id"""
+        SQL = SQL_TEMPLATES["has_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
+        # print(SQL)
+        # print(self.db.workspace, node_id)
         res = await self.db.query(SQL)
         if res:
-            #print("Node exist!",res)
+            # print("Node exist!",res)
             return True
         else:
-            #print("Node not exist!")
+            # print("Node not exist!")
             return False
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by source and target node ids"""
-        SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
-                                               source_node_id=source_node_id,
-                                               target_node_id=target_node_id)
+        SQL = SQL_TEMPLATES["has_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Edge exist!",res)
+            # print("Edge exist!",res)
             return True
         else:
-            #print("Edge not exist!")
+            # print("Edge not exist!")
             return False
 
     async def node_degree(self, node_id: str) -> int:
-        """Get a node's degree by node id"""
-        SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
+        """Get a node's degree by node id"""
+        SQL = SQL_TEMPLATES["node_degree"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Node degree",res["degree"])
+            # print("Node degree",res["degree"])
             return res["degree"]
         else:
-            #print("Edge not exist!")
+            # print("Edge not exist!")
             return 0
 
-
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get an edge's degree by source and target node ids"""
         degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
-        #print("Edge degree",degree)
+        # print("Edge degree",degree)
         return degree
 
-
    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Get node data by node id"""
-        SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
+        SQL = SQL_TEMPLATES["get_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(self.db.workspace, node_id)
         # print(SQL)
         res = await self.db.query(SQL)
         if res:
-            #print("Get node!",self.db.workspace, node_id,res)
+            # print("Get node!",self.db.workspace, node_id,res)
             return res
         else:
-            #print("Can't get node!",self.db.workspace, node_id)
+            # print("Can't get node!",self.db.workspace, node_id)
             return None
-
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by source and target node ids"""
-        SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
-                                               source_node_id=source_node_id,
-                                               target_node_id=target_node_id)
+        SQL = SQL_TEMPLATES["get_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         res = await self.db.query(SQL)
         if res:
-            #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
+            # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
         else:
-            #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
+            # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
             return None
 
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id"""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
-                                                         source_node_id=source_node_id)
+            SQL = SQL_TEMPLATES["get_node_edges"].format(
+                workspace=self.db.workspace, source_node_id=source_node_id
+            )
             res = await self.db.query(sql=SQL, multirows=True)
             if res:
-                data = [(i["source_name"],i["target_name"]) for i in res]
-                #print("Get node edge!",self.db.workspace, source_node_id,data)
+                data = [(i["source_name"], i["target_name"]) for i in res]
+                # print("Get node edge!",self.db.workspace, source_node_id,data)
                 return data
             else:
-                #print("Node Edge not exist!",self.db.workspace, source_node_id)
+                # print("Node Edge not exist!",self.db.workspace, source_node_id)
                 return []
 
@@ -487,12 +531,12 @@
 N_T = {
     "text_chunks": "LIGHTRAG_DOC_CHUNKS",
     "chunks": "LIGHTRAG_DOC_CHUNKS",
     "entities": "LIGHTRAG_GRAPH_NODES",
-    "relationships": "LIGHTRAG_GRAPH_EDGES"
+    "relationships": "LIGHTRAG_GRAPH_EDGES",
 }
 
 TABLES = {
-    "LIGHTRAG_DOC_FULL":
-    {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
+    "LIGHTRAG_DOC_FULL": {
+        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                     id varchar(256)PRIMARY KEY,
                     workspace varchar(1024),
                     doc_name varchar(1024),
                     content CLOB,
                     meta JSON,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-    )"""},
-
-    "LIGHTRAG_DOC_CHUNKS":
-    {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
+    )"""
+    },
+    "LIGHTRAG_DOC_CHUNKS": {
+        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                     id varchar(256) PRIMARY KEY,
                     workspace varchar(1024),
                     full_doc_id varchar(256),
                     chunk_order_index NUMBER,
-                    tokens NUMBER,
+                    tokens NUMBER,
                     content CLOB,
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP DEFAULT NULL
-    )"""},
-
-    "LIGHTRAG_GRAPH_NODES":
-    {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
+                    updatetime TIMESTAMP DEFAULT NULL
+    )"""
+    },
+    "LIGHTRAG_GRAPH_NODES": {
+        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
                     id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                     workspace varchar(1024),
                     name varchar(2048),
-                    entity_type varchar(1024),
+                    entity_type varchar(1024),
                     description CLOB,
                     source_chunk_id varchar(256),
                     content CLOB,
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-    )"""},
-    "LIGHTRAG_GRAPH_EDGES":
-    {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
+    )"""
+    },
+    "LIGHTRAG_GRAPH_EDGES": {
+        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
                     id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                     workspace varchar(1024),
                     source_name varchar(2048),
-                    target_name varchar(2048),
+                    target_name varchar(2048),
                     weight NUMBER,
-                    keywords CLOB,
+                    keywords CLOB,
                     description CLOB,
                     source_chunk_id varchar(256),
                     content CLOB,
                     content_vector VECTOR,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-    )"""},
-    "LIGHTRAG_LLM_CACHE":
-    {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
+    )"""
+    },
+    "LIGHTRAG_LLM_CACHE": {
+        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                     id varchar(256) PRIMARY KEY,
                     send clob,
                     return clob,
                     model varchar(1024),
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
-    )"""},
-
-    "LIGHTRAG_GRAPH":
-    {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
+    )"""
+    },
"LIGHTRAG_GRAPH": { + "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph VERTEX TABLES ( lightrag_graph_nodes KEY (id) LABEL entity @@ -565,93 +611,67 @@ TABLES = { SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name) DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name) LABEL has_relation - PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) - ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""}, - } + PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) + ) OPTIONS(ALLOW MIXED PROPERTY TYPES)""" + }, +} SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": - "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", - - "get_by_id_text_chunks": - "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", - - "get_by_ids_full_docs": - "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", - - "get_by_ids_text_chunks": - "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", - - "filter_keys": - "select id from {table_name} where workspace='{workspace}' and id in ({ids})", - + "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})", "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = '{check_id}') WHEN NOT MATCHED THEN INSERT(id,content,workspace) values(:1,:2,:3) """, - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a USING DUAL ON (a.id = '{check_id}') WHEN NOT MATCHED THEN INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) values (:1,:2,:3,:4,:5,:6,:7) """, - # SQL for VectorStorage - "entities": - """SELECT name as entity_name FROM - (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') + "entities": """SELECT name as entity_name FROM + (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - - "relationships": - """SELECT source_name as src_id, target_name as tgt_id FROM - (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') + "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM + (SELECT 
id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - - "chunks": - """SELECT id FROM - (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') + "chunks": """SELECT id FROM + (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - # SQL for GraphStorage - "has_node": - """SELECT * FROM GRAPH_TABLE (lightrag_graph + "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) WHERE a.workspace='{workspace}' AND a.name='{node_id}' COLUMNS (a.name))""", - - "has_edge": - """SELECT * FROM GRAPH_TABLE (lightrag_graph + "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) -[e]-> (b) WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{source_node_id}' AND b.name='{target_node_id}' COLUMNS (e.source_name,e.target_name) )""", - - "node_degree": - """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph + "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{node_id}' or b.name = '{node_id}' COLUMNS (a.name))""", - - "get_node": - """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description + "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description FROM GRAPH_TABLE (lightrag_graph - MATCH (a) + MATCH (a) WHERE a.workspace='{workspace}' AND a.name='{node_id}' COLUMNS (a.name) ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name WHERE t2.workspace='{workspace}'""", - - "get_edge": - """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, + "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) @@ -659,15 +679,12 @@ SQL_TEMPLATES = { AND a.name='{source_node_id}' and b.name = '{target_node_id}' COLUMNS (e.id,a.name as source_id) ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", - - "get_node_edges": - """SELECT source_name,target_name + "get_node_edges": """SELECT source_name,target_name FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{source_node_id}' COLUMNS (a.name as source_name,b.name as target_name))""", - "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') @@ -679,5 +696,5 @@ SQL_TEMPLATES = { ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """ - } \ No newline at end of file + 
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 52786970..50e33405 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -38,15 +38,11 @@ from .storage import ( JsonKVStorage, NanoVectorDBStorage, NetworkXStorage, - ) +) from .kg.neo4j_impl import Neo4JStorage -from .kg.oracle_impl import ( - OracleKVStorage, - OracleGraphStorage, - OracleVectorDBStorage - ) +from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage # future KG integrations @@ -54,6 +50,7 @@ from .kg.oracle_impl import ( # GraphStorage as ArangoDBStorage # ) + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: return asyncio.get_event_loop() @@ -72,7 +69,7 @@ class LightRAG: default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) - kv_storage : str = field(default="JsonKVStorage") + kv_storage: str = field(default="JsonKVStorage") vector_storage: str = field(default="NanoVectorDBStorage") graph_storage: str = field(default="NetworkXStorage") @@ -115,7 +112,7 @@ class LightRAG: # storage vector_db_storage_cls_kwargs: dict = field(default_factory=dict) - + enable_llm_cache: bool = True # extension @@ -134,18 +131,25 @@ class LightRAG: # @TODO: should move all storage setup here to leverage initial start params attached to self. - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage] - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage] + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class()[self.kv_storage] + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ + self.vector_storage + ] + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ + self.graph_storage + ] if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = ( self.key_string_value_json_storage_cls( - namespace="llm_response_cache", global_config=asdict(self),embedding_func=None + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, ) if self.enable_llm_cache else None @@ -159,13 +163,19 @@ class LightRAG: # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( - namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func + namespace="full_docs", + global_config=asdict(self), + embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( - namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func + namespace="text_chunks", + global_config=asdict(self), + embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func + namespace="chunk_entity_relation", + global_config=asdict(self), + embedding_func=self.embedding_func, ) #### # add embedding func by walter over @@ -200,13 +210,11 @@ class LightRAG: def _get_storage_class(self) -> Type[BaseGraphStorage]: return { # kv storage - "JsonKVStorage":JsonKVStorage, - "OracleKVStorage":OracleKVStorage, - + "JsonKVStorage": JsonKVStorage, + "OracleKVStorage": OracleKVStorage, # vector storage - 
"NanoVectorDBStorage":NanoVectorDBStorage, - "OracleVectorDBStorage":OracleVectorDBStorage, - + "NanoVectorDBStorage": NanoVectorDBStorage, + "OracleVectorDBStorage": OracleVectorDBStorage, # graph storage "NetworkXStorage": NetworkXStorage, "Neo4JStorage": Neo4JStorage, diff --git a/lightrag/operate.py b/lightrag/operate.py index 5288c26d..db7c9401 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -16,7 +16,7 @@ from .utils import ( split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, - locate_json_string_body_from_string + locate_json_string_body_from_string, ) from .base import ( BaseGraphStorage, diff --git a/requirements.txt b/requirements.txt index d57de091..6adb6929 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,22 @@ accelerate +aioboto3 aiohttp + +# database packages +graspologic +hnswlib +nano-vectordb +neo4j +networkx +ollama +openai +oracledb pyvis tenacity -xxhash # lmdeploy[all] # LLM packages tiktoken torch transformers -aioboto3 -ollama -openai - -# database packages -graspologic -hnswlib -networkx -oracledb -nano-vectordb -neo4j \ No newline at end of file +xxhash