From 1bc4e2382b4761d13d83363edfe82f5a470acccd Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:58:41 +0800 Subject: [PATCH 01/14] Oracle Database support Add oracle 23ai database as the KV/vector/graph storage --- examples/lightrag_oracle_demo.py | 127 +++++ lightrag/base.py | 2 + lightrag/kg/oracle_impl.py | 767 +++++++++++++++++++++++++++++++ lightrag/lightrag.py | 90 ++-- lightrag/prompt.py | 20 +- requirements.txt | 1 + 6 files changed, 972 insertions(+), 35 deletions(-) create mode 100644 examples/lightrag_oracle_demo.py create mode 100644 lightrag/kg/oracle_impl.py diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py new file mode 100644 index 00000000..94a47965 --- /dev/null +++ b/examples/lightrag_oracle_demo.py @@ -0,0 +1,127 @@ + + +import sys, os +print(os.getcwd()) +from pathlib import Path +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) + +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np +from datetime import datetime + +from lightrag.kg.oracle_impl import OracleDB + + +WORKING_DIR = "./dickens" + +# We use OpenAI compatible API to call LLM on Oracle Cloud +# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway +BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" +APIKEY = "ocigenerativeai" +CHATMODEL = "cohere.command-r-plus" +EMBEDMODEL = "cohere.embed-multilingual-v3.0" + + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + CHATMODEL, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=APIKEY, + base_url=BASE_URL, + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model=EMBEDMODEL, + api_key=APIKEY, + base_url=BASE_URL, + ) + + +async def get_embedding_dim(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + return embedding_dim + + +async def main(): + try: + # Detect embedding dimension + embedding_dimension = await get_embedding_dim() + print(f"Detected embedding dimension: {embedding_dimension}") + + # Create Oracle DB connection + # The `config` parameter is the connection configuration of Oracle DB + # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html + # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query + # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud + oracle_db = OracleDB(config={ + "user":"RAG", + "password":"xxxxxxxxx", + "dsn":"xxxxxxx_medium", + "config_dir":"dir/path/to/oracle/config", + "wallet_location":"dir/path/to/oracle/wallet", + "wallet_password":"xxxxxxxxx", + "workspace":"company" # specify which docs we want to store and query + } + ) + + + # Check if Oracle DB tables exist, if not, tables will be created + await oracle_db.check_tables() + + + # Initialize LightRAG + # We use Oracle DB as the KV/vector/graph storage + rag = LightRAG( + enable_llm_cache=False, + working_dir=WORKING_DIR, + chunk_token_size=512, + llm_model_func=llm_model_func, + 
embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=512, + func=embedding_func, + ), + graph_storage = "OracleGraphStorage", + kv_storage="OracleKVStorage", + vector_storage="OracleVectorDBStorage" + ) + + # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool + rag.graph_storage_cls.db = oracle_db + rag.key_string_value_json_storage_cls.db = oracle_db + rag.vector_db_storage_cls.db = oracle_db + + # Extract and Insert into LightRAG storage + with open("./dickens/demo.txt", "r", encoding="utf-8") as f: + await rag.ainsert(f.read()) + + # Perform search in different modes + modes = ["naive", "local", "global", "hybrid"] + for mode in modes: + print("="*20, mode, "="*20) + print(await rag.aquery("这个文章讲了什么?", param=QueryParam(mode=mode))) + print("-"*100, "\n") + + except Exception as e: + print(f"An error occurred: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/lightrag/base.py b/lightrag/base.py index cecd5edd..97524472 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -59,6 +59,7 @@ class BaseVectorStorage(StorageNameSpace): @dataclass class BaseKVStorage(Generic[T], StorageNameSpace): + embedding_func: EmbeddingFunc async def all_keys(self) -> list[str]: raise NotImplementedError @@ -83,6 +84,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): + embedding_func: EmbeddingFunc async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py new file mode 100644 index 00000000..ab6312a7 --- /dev/null +++ b/lightrag/kg/oracle_impl.py @@ -0,0 +1,767 @@ +import asyncio +#import html +#import os +from dataclasses import dataclass +from typing import Any, Union, cast +import networkx as nx +import numpy as np +import array + +from ..utils import logger +from ..base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, +) + +import oracledb + +class OracleDB: + def __init__(self,config,**kwargs): + self.host = config.get("host", None) + self.port = config.get("port", None) + self.user = config.get("user", None) + self.password = config.get("password", None) + self.dsn = config.get("dsn", None) + self.config_dir = config.get("config_dir", None) + self.wallet_location = config.get("wallet_location", None) + self.wallet_password = config.get("wallet_password", None) + self.workspace = config.get("workspace", None) + self.max = 12 + self.increment = 1 + logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") + if self.user is None or self.password is None: + raise ValueError("Missing database user or password in addon_params") + + try: + oracledb.defaults.fetch_lobs = False + + self.pool = oracledb.create_pool_async( + user = self.user, + password = self.password, + dsn = self.dsn, + config_dir = self.config_dir, + wallet_location = self.wallet_location, + wallet_password = self.wallet_password, + min = 1, + max = self.max, + increment = self.increment + ) + logger.info(f"Connected to Oracle database at {self.dsn}") + except Exception as e: + logger.error(f"Failed to connect to Oracle database at {self.dsn}") + logger.error(f"Oracle database error: {e}") + raise + + def numpy_converter_in(self, value): + """Convert numpy array to array.array""" + if value.dtype == np.float64: + dtype = "d" + elif value.dtype == np.float32: + dtype = "f" + else: + dtype = "b" + return array.array(dtype, value) + + def 
input_type_handler(self, cursor, value, arraysize): + """Set the type handler for the input data""" + if isinstance(value, np.ndarray): + return cursor.var( + oracledb.DB_TYPE_VECTOR, + arraysize=arraysize, + inconverter=self.numpy_converter_in, + ) + + def numpy_converter_out(self, value): + """Convert array.array to numpy array""" + if value.typecode == "b": + dtype = np.int8 + elif value.typecode == "f": + dtype = np.float32 + else: + dtype = np.float64 + return np.array(value, copy=False, dtype=dtype) + + def output_type_handler(self, cursor, metadata): + """Set the type handler for the output data""" + if metadata.type_code is oracledb.DB_TYPE_VECTOR: + return cursor.var( + metadata.type_code, + arraysize=cursor.arraysize, + outconverter=self.numpy_converter_out, + ) + + async def check_tables(self): + for k,v in TABLES.items(): + try: + if k.lower() == "lightrag_graph": + await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only") + else: + await self.query("SELECT 1 FROM {k}".format(k=k)) + except Exception as e: + logger.error(f"Failed to check table {k} in Oracle database") + logger.error(f"Oracle database error: {e}") + try: + # print(v["ddl"]) + await self.execute(v["ddl"]) + logger.info(f"Created table {k} in Oracle database") + except Exception as e: + logger.error(f"Failed to create table {k} in Oracle database") + logger.error(f"Oracle database error: {e}") + + logger.info(f"Finished check all tables in Oracle database") + + + async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]: + async with self.pool.acquire() as connection: + connection.inputtypehandler = self.input_type_handler + connection.outputtypehandler = self.output_type_handler + with connection.cursor() as cursor: + try: + await cursor.execute(sql) + except Exception as e: + logger.error(f"Oracle database error: {e}") + print(sql) + raise + columns = [column[0].lower() for column in cursor.description] + if multirows: + rows = await cursor.fetchall() + if rows: + data = [dict(zip(columns, row)) for row in rows] + else: + data = [] + else: + row = await cursor.fetchone() + if row: + data = dict(zip(columns, row)) + else: + data = None + return data + + async def execute(self,sql: str, data: list = None): + # logger.info("go into OracleDB execute method") + try: + async with self.pool.acquire() as connection: + connection.inputtypehandler = self.input_type_handler + connection.outputtypehandler = self.output_type_handler + with connection.cursor() as cursor: + if data is None: + await cursor.execute(sql) + else: + #print(data) + #print(sql) + await cursor.execute(sql,data) + await connection.commit() + except Exception as e: + logger.error(f"Oracle database error: {e}") + print(sql) + print(data) + raise + +@dataclass +class OracleKVStorage(BaseKVStorage): + + # should pass db object to self.db + def __post_init__(self): + self._data = {} + self._max_batch_size = self.global_config["embedding_batch_num"] + + ################ QUERY METHODS ################ + + async def get_by_id(self, id: str) -> Union[dict, None]: + """根据 id 获取 doc_full 数据.""" + SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id) + #print("get_by_id:"+SQL) + res = await self.db.query(SQL) + if res: + data = res #{"data":res} + #print (data) + return data + else: + return None + + # Query by id + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]: + """根据 id 获取 doc_chunks 数据""" + SQL = 
SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace, + ids=",".join([f"'{id}'" for id in ids])) + #print("get_by_ids:"+SQL) + res = await self.db.query(SQL,multirows=True) + if res: + data = res # [{"data":i} for i in res] + #print(data) + return data + else: + return None + + async def filter_keys(self, keys: list[str]) -> set[str]: + """过滤掉重复内容""" + SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], + workspace=self.db.workspace, + ids=",".join([f"'{k}'" for k in keys])) + res = await self.db.query(SQL,multirows=True) + data = None + if res: + exist_keys = [key["id"] for key in res] + data = set([s for s in keys if s not in exist_keys]) + else: + exist_keys = [] + data = set([s for s in keys if s not in exist_keys]) + return data + + + ################ INSERT METHODS ################ + async def upsert(self, data: dict[str, dict]): + left_data = {k: v for k, v in data.items() if k not in self._data} + self._data.update(left_data) + #print(self._data) + #values = [] + if self.namespace == "text_chunks": + list_data = [ + { + "__id__": k, + **{k1: v1 for k1, v1 in v.items()}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i: i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + #print(list_data) + for item in list_data: + merge_sql = SQL_TEMPLATES["merge_chunk"].format( + check_id=item["__id__"] + ) + + values = [item["__id__"], item["content"], self.db.workspace, item["tokens"], + item["chunk_order_index"], item["full_doc_id"], item["__vector__"]] + #print(merge_sql) + await self.db.execute(merge_sql, values) + + if self.namespace == "full_docs": + for k, v in self._data.items(): + #values.clear() + merge_sql = SQL_TEMPLATES["merge_doc_full"].format( + check_id=k, + ) + values = [k, self._data[k]["content"], self.db.workspace] + #print(merge_sql) + await self.db.execute(merge_sql, values) + return left_data + + + async def index_done_callback(self): + if self.namespace in ["full_docs", "text_chunks"]: + logger.info("full doc and chunk data had been saved into oracle db!") + + + +@dataclass +class OracleVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + pass + + async def upsert(self, data: dict[str, dict]): + """向向量数据库中插入数据""" + pass + + async def index_done_callback(self): + pass + + + #################### query method ################ + async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + """从向量数据库中查询数据""" + embeddings = await self.embedding_func([query]) + embedding = embeddings[0] + # 转换精度 + dtype = str(embedding.dtype).upper() + dimension = embedding.shape[0] + embedding_string = ', '.join(map(str, embedding.tolist())) + + SQL = SQL_TEMPLATES[self.namespace].format( + embedding_string=embedding_string, + dimension=dimension, + dtype=dtype, + workspace=self.db.workspace, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + # print(SQL) + results = await self.db.query(SQL, multirows=True) + #print("vector search result:",results) + return results + + +@dataclass +class OracleGraphStorage(BaseGraphStorage): + """基于Oracle的图存储模块""" + # @staticmethod + # def load_graph(file_name) -> nx.Graph: + # """读取graphhml图文件""" + + # 
@staticmethod + # def write_graph(graph: nx.Graph, file_name): + # # """写入graphhml图文件""" + + # @staticmethod + # def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + # Return the largest connected component of the graph, with nodes and edges sorted in a stable way. + # 用于产生稳定的最大连通分量的模块,即相同的输入图==相同的输出lcc。 + # """ + + + # @staticmethod + # def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + # Ensure an undirected graph with the same relationships will always be read the same way. + # 确保具有相同关系的无向图始终以相同的方式读取。 + # """ + + def __post_init__(self): + """从graphml文件加载图""" + self._max_batch_size = self.global_config["embedding_batch_num"] + + + #################### insert method ################ + + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """插入或更新节点""" + #print("go into upsert node method") + entity_name = node_id + entity_type = node_data["entity_type"] + description = node_data["description"] + source_id = node_data["source_id"] + content = entity_name+description + contents = [content] + batches = [ + contents[i: i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + content_vector = embeddings[0] + merge_sql = SQL_TEMPLATES["merge_node"].format( + workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id + ) + #print(merge_sql) + await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector]) + #self._graph.add_node(node_id, **node_data) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + """插入或更新边""" + #print("go into upsert edge method") + source_name = source_node_id + target_name = target_node_id + weight = edge_data["weight"] + keywords = edge_data["keywords"] + description = edge_data["description"] + source_chunk_id = edge_data["source_id"] + content = keywords+source_name+target_name+description + contents = [content] + batches = [ + contents[i: i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + content_vector = embeddings[0] + merge_sql = SQL_TEMPLATES["merge_edge"].format( + workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id + ) + #print(merge_sql) + await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector]) + #self._graph.add_edge(source_node_id, target_node_id, **edge_data) + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """为节点生成向量""" + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() + + async def _node2vec_embed(self): + """为节点生成向量""" + from graspologic import embed + + embeddings, nodes = embed.node2vec_embed( + self._graph, + **self.config["node2vec_params"], + ) + + nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + return embeddings, nodes_ids + + + 
async def index_done_callback(self): + """写入graphhml图文件""" + logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!") + + #################### query method ################ + async def has_node(self, node_id: str) -> bool: + """根据节点id检查节点是否存在""" + SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id) + # print(SQL) + #print(self.db.workspace, node_id) + res = await self.db.query(SQL) + if res: + #print("Node exist!",res) + return True + else: + #print("Node not exist!") + return False + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """根据源和目标节点id检查边是否存在""" + SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace, + source_node_id=source_node_id, + target_node_id=target_node_id) + # print(SQL) + res = await self.db.query(SQL) + if res: + #print("Edge exist!",res) + return True + else: + #print("Edge not exist!") + return False + + async def node_degree(self, node_id: str) -> int: + """根据节点id获取节点的度""" + SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id) + # print(SQL) + res = await self.db.query(SQL) + if res: + #print("Node degree",res["degree"]) + return res["degree"] + else: + #print("Edge not exist!") + return 0 + + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """根据源和目标节点id获取边的度""" + degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) + #print("Edge degree",degree) + return degree + + + async def get_node(self, node_id: str) -> Union[dict, None]: + """根据节点id获取节点数据""" + SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id) + # print(self.db.workspace, node_id) + # print(SQL) + res = await self.db.query(SQL) + if res: + #print("Get node!",self.db.workspace, node_id,res) + return res + else: + #print("Can't get node!",self.db.workspace, node_id) + return None + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + """根据源和目标节点id获取边""" + SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace, + source_node_id=source_node_id, + target_node_id=target_node_id) + res = await self.db.query(SQL) + if res: + #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) + return res + else: + #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) + return None + + async def get_node_edges(self, source_node_id: str): + """根据节点id获取节点的所有边""" + if await self.has_node(source_node_id): + SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace, + source_node_id=source_node_id) + res = await self.db.query(sql=SQL, multirows=True) + if res: + data = [(i["source_name"],i["target_name"]) for i in res] + #print("Get node edge!",self.db.workspace, source_node_id,data) + return data + else: + #print("Node Edge not exist!",self.db.workspace, source_node_id) + return [] + + #################### INSERT method ################ + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """插入或更新节点""" + #print("go into upsert node method") + entity_name = node_id + entity_type = node_data["entity_type"] + description = node_data["description"] + source_id = node_data["source_id"] + content = entity_name+description + contents = [content] + batches = [ + contents[i: i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = 
np.concatenate(embeddings_list) + content_vector = embeddings[0] + merge_sql = SQL_TEMPLATES["merge_node"].format( + workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id + ) + #print(merge_sql) + await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector]) + #self._graph.add_node(node_id, **node_data) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + """插入或更新边""" + #print("go into upsert edge method") + source_name = source_node_id + target_name = target_node_id + weight = edge_data["weight"] + keywords = edge_data["keywords"] + description = edge_data["description"] + source_chunk_id = edge_data["source_id"] + content = keywords+source_name+target_name+description + contents = [content] + batches = [ + contents[i: i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + content_vector = embeddings[0] + merge_sql = SQL_TEMPLATES["merge_edge"].format( + workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id + ) + #print(merge_sql) + await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector]) + #self._graph.add_edge(source_node_id, target_node_id, **edge_data) + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """为节点生成向量""" + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() + + async def _node2vec_embed(self): + """为节点生成向量""" + from graspologic import embed + + embeddings, nodes = embed.node2vec_embed( + self._graph, + **self.config["node2vec_params"], + ) + + nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + return embeddings, nodes_ids + + +N_T = { + "full_docs": "LIGHTRAG_DOC_FULL", + "text_chunks": "LIGHTRAG_DOC_CHUNKS", + "chunks": "LIGHTRAG_DOC_CHUNKS", + "entities": "LIGHTRAG_GRAPH_NODES", + "relationships": "LIGHTRAG_GRAPH_EDGES" +} + +TABLES = { + "LIGHTRAG_DOC_FULL": + {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL ( + id varchar(256)PRIMARY KEY, + workspace varchar(1024), + doc_name varchar(1024), + content CLOB, + meta JSON + )"""}, + + "LIGHTRAG_DOC_CHUNKS": + {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS ( + id varchar(256) PRIMARY KEY, + workspace varchar(1024), + full_doc_id varchar(256), + chunk_order_index NUMBER, + tokens NUMBER, + content CLOB, + content_vector VECTOR + )"""}, + + "LIGHTRAG_GRAPH_NODES": + {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES ( + id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + workspace varchar(1024), + name varchar(2048), + entity_type varchar(1024), + description CLOB, + source_chunk_id varchar(256), + content CLOB, + content_vector VECTOR + )"""}, + "LIGHTRAG_GRAPH_EDGES": + {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES ( + id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + workspace varchar(1024), + source_name varchar(2048), + target_name varchar(2048), + weight NUMBER, + keywords CLOB, + description CLOB, + source_chunk_id varchar(256), + content CLOB, + content_vector VECTOR + )"""}, + "LIGHTRAG_LLM_CACHE": + {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE ( + id varchar(256) PRIMARY KEY, + return clob, + model 
varchar(1024) + )"""}, + + "LIGHTRAG_GRAPH": + {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph + VERTEX TABLES ( + lightrag_graph_nodes KEY (id) + LABEL entity + PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id) + ) + EDGE TABLES ( + lightrag_graph_edges KEY (id) + SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name) + DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name) + LABEL has_relation + PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) + ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""}, + } + + +SQL_TEMPLATES = { + # SQL for KVStorage + "get_by_id_full_docs": + "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", + + "get_by_id_text_chunks": + "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", + + "get_by_ids_full_docs": + "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", + + "get_by_ids_text_chunks": + "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", + + "filter_keys": + "select id from {table_name} where workspace='{workspace}' and id in ({ids})", + + "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = '{check_id}') + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:1,:2,:3) + """, + + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a + USING DUAL + ON (a.id = '{check_id}') + WHEN NOT MATCHED THEN + INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) + values (:1,:2,:3,:4,:5,:6,:7) """, + + # SQL for VectorStorage + "entities": + """SELECT name as entity_name FROM + (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') + WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + + "relationships": + """SELECT source_name as src_id, target_name as tgt_id FROM + (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') + WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + + "chunks": + """SELECT id FROM + (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') + WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + + # SQL for GraphStorage + "has_node": + """SELECT * FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace='{workspace}' AND a.name='{node_id}' + COLUMNS (a.name))""", + + "has_edge": + """SELECT * FROM GRAPH_TABLE (lightrag_graph + MATCH (a) -[e]-> (b) + WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' + AND a.name='{source_node_id}' AND b.name='{target_node_id}' + COLUMNS (e.source_name,e.target_name) )""", + + "node_degree": + """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) + WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' + AND a.name='{node_id}' or b.name = '{node_id}' 
+ COLUMNS (a.name))""", + + "get_node": + """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description + FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace='{workspace}' AND a.name='{node_id}' + COLUMNS (a.name) + ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name + WHERE t2.workspace='{workspace}'""", + + "get_edge": + """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, + NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords + FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) + WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' + AND a.name='{source_node_id}' and b.name = '{target_node_id}' + COLUMNS (e.id,a.name as source_id) + ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", + + "get_node_edges": + """SELECT source_name,target_name + FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) + WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' + AND a.name='{source_node_id}' + COLUMNS (a.name as source_name,b.name as target_name))""", + + "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a + USING DUAL + ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') + WHEN NOT MATCHED THEN + INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) + values (:1,:2,:3,:4,:5,:6,:7) """, + "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a + USING DUAL + ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') + WHEN NOT MATCHED THEN + INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) + values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """ + } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2ae59f3b..d6a82d71 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -18,20 +18,6 @@ from .operate import ( naive_query, ) -from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, -) - -from .kg.neo4j_impl import Neo4JStorage -# future KG integrations - -# from .kg.ArangoDB_impl import ( -# GraphStorage as ArangoDBStorage -# ) - - from .utils import ( EmbeddingFunc, compute_mdhash_id, @@ -49,6 +35,26 @@ from .base import ( ) +from .storage import ( + JsonKVStorage, + NanoVectorDBStorage, + NetworkXStorage, + ) + +from .kg.neo4j_impl import Neo4JStorage + +from .kg.oracle_impl import ( + OracleKVStorage, + OracleGraphStorage, + OracleVectorDBStorage + ) + +# future KG integrations + +# from .kg.ArangoDB_impl import ( +# GraphStorage as ArangoDBStorage +# ) + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: return asyncio.get_event_loop() @@ -68,7 +74,9 @@ class LightRAG: default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) - kg: str = field(default="NetworkXStorage") + kv_storage : str = field(default="JsonKVStorage") + vector_storage: str = field(default="NanoVectorDBStorage") + graph_storage: str = field(default="NetworkXStorage") current_log_level = logger.level log_level: str = field(default=current_log_level) @@ -108,9 +116,16 @@ class LightRAG: llm_model_kwargs: dict = field(default_factory=dict) # storage - key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage - vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage + vector_db_storage_cls_kwargs: dict = field(default_factory=dict) + # if 
DATABASE_TYPE is None: + # key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage + # vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage + # vector_db_storage_cls_kwargs: dict = field(default_factory=dict) + # elif DATABASE_TYPE == "oracle": + # key_string_value_json_storage_cls: Type[BaseKVStorage] = OracleKVStorage, + # vector_db_storage_cls: Type[BaseVectorStorage] = OracleVectorDBStorage, + enable_llm_cache: bool = True # extension @@ -128,21 +143,16 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # @TODO: should move all storage setup here to leverage initial start params attached to self. - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ - self.kg - ] + self. key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage] + + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage] + + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage] if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.full_docs = self.key_string_value_json_storage_cls( - namespace="full_docs", global_config=asdict(self) - ) - - self.text_chunks = self.key_string_value_json_storage_cls( - namespace="text_chunks", global_config=asdict(self) - ) self.llm_response_cache = ( self.key_string_value_json_storage_cls( @@ -151,14 +161,27 @@ class LightRAG: if self.enable_llm_cache else None ) - self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace="chunk_entity_relation", global_config=asdict(self) - ) self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) + #### + # add embedding func by walter + #### + self.full_docs = self.key_string_value_json_storage_cls( + namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func + ) + self.text_chunks = self.key_string_value_json_storage_cls( + namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func + ) + self.chunk_entity_relation_graph = self.graph_storage_cls( + namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func + ) + #### + # add embedding func by walter over + #### + self.entities_vdb = self.vector_db_storage_cls( namespace="entities", global_config=asdict(self), @@ -187,8 +210,15 @@ class LightRAG: def _get_storage_class(self) -> Type[BaseGraphStorage]: return { + "JsonKVStorage":JsonKVStorage, + "OracleKVStorage":OracleKVStorage, + + "NanoVectorDBStorage":NanoVectorDBStorage, + "OracleVectorDBStorage":OracleVectorDBStorage, + "Neo4JStorage": Neo4JStorage, "NetworkXStorage": NetworkXStorage, + "OracleGraphStorage": OracleGraphStorage, # "ArangoDBStorage": ArangoDBStorage } diff --git a/lightrag/prompt.py b/lightrag/prompt.py index ba2516d8..e0713859 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -222,14 +222,24 @@ Output: """ -PROMPTS["naive_rag_response"] = """You're a helpful assistant -Below are the knowledge you know: -{content_data} ---- -If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up. +PROMPTS["naive_rag_response"] = """---Role--- + +You are a helpful assistant responding to questions about documents provided. 
+ + +---Goal--- + Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. If you don't know the answer, just say so. Do not make anything up. Do not include information where the supporting evidence for it is not provided. + ---Target response length and format--- + {response_type} + +---Documents--- + +{content_data} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. """ diff --git a/requirements.txt b/requirements.txt index 8620fe10..890659cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ torch transformers xxhash # lmdeploy[all] +oracledb From 7fadd914bbb480557199344a929e04da35ee8fad Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:02:29 +0800 Subject: [PATCH 02/14] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b726f605..11b55ed6 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2024.11.08]🎯📢You can [use Oracle Database 23ai for Storage](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now. - [x] [2024.11.04]🎯📢You can [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage) now. - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`. - [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization. From b690e071bff770420ec78c07f2b93390b1787c9a Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:12:58 +0800 Subject: [PATCH 03/14] support Oracle Database storage --- README.md | 2 +- examples/lightrag_oracle_demo.py | 2 +- lightrag/lightrag.py | 19 ++++++------------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 11b55ed6..d4893769 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News -- [x] [2024.11.08]🎯📢You can [use Oracle Database 23ai for Storage](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now. +- [x] [2024.11.08]🎯📢You can [use Oracle Database 23ai for all storage types (kv/vector/graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now. - [x] [2024.11.04]🎯📢You can [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage) now. - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`. - [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization. 
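A note on the vector search used by these storage backends: the `entities`, `relationships`, and `chunks` entries in `SQL_TEMPLATES` (added in lightrag/kg/oracle_impl.py above) splice the query embedding directly into the statement as an Oracle 23ai vector literal. A minimal sketch of how `OracleVectorDBStorage.query` renders the `chunks` template; the embedding values below are invented purely for illustration:

import numpy as np

# Toy 4-dimensional embedding standing in for real model output.
embedding = np.array([0.12, -0.03, 0.54, 0.91], dtype=np.float32)

SQL = SQL_TEMPLATES["chunks"].format(
    embedding_string=", ".join(map(str, embedding.tolist())),
    dimension=embedding.shape[0],        # -> 4
    dtype=str(embedding.dtype).upper(),  # -> "FLOAT32"
    workspace="company",
    top_k=5,
    better_than_threshold=0.2,
)
# The rendered SQL contains vector('[0.12, -0.03, 0.54, 0.91]', 4, FLOAT32)
# and returns up to 5 chunk ids, filtered by the threshold and ordered by
# ascending cosine distance.
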
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 94a47965..93f0799d 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -116,7 +116,7 @@ async def main(): modes = ["naive", "local", "global", "hybrid"] for mode in modes: print("="*20, mode, "="*20) - print(await rag.aquery("这个文章讲了什么?", param=QueryParam(mode=mode))) + print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode))) print("-"*100, "\n") except Exception as e: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d6a82d71..820051b7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -34,7 +34,6 @@ from .base import ( QueryParam, ) - from .storage import ( JsonKVStorage, NanoVectorDBStorage, @@ -116,15 +115,7 @@ class LightRAG: llm_model_kwargs: dict = field(default_factory=dict) # storage - vector_db_storage_cls_kwargs: dict = field(default_factory=dict) - # if DATABASE_TYPE is None: - # key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage - # vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage - # vector_db_storage_cls_kwargs: dict = field(default_factory=dict) - # elif DATABASE_TYPE == "oracle": - # key_string_value_json_storage_cls: Type[BaseKVStorage] = OracleKVStorage, - # vector_db_storage_cls: Type[BaseVectorStorage] = OracleVectorDBStorage, enable_llm_cache: bool = True @@ -144,11 +135,10 @@ class LightRAG: # @TODO: should move all storage setup here to leverage initial start params attached to self. - self. key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage] - + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage] + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage] self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage] - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage] if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -210,14 +200,17 @@ class LightRAG: def _get_storage_class(self) -> Type[BaseGraphStorage]: return { + # kv storage "JsonKVStorage":JsonKVStorage, "OracleKVStorage":OracleKVStorage, + # vector storage "NanoVectorDBStorage":NanoVectorDBStorage, "OracleVectorDBStorage":OracleVectorDBStorage, - "Neo4JStorage": Neo4JStorage, + # graph storage "NetworkXStorage": NetworkXStorage, + "Neo4JStorage": Neo4JStorage, "OracleGraphStorage": OracleGraphStorage, # "ArangoDBStorage": ArangoDBStorage } From 6e56c8343ada2836c383841f523db191cca38da1 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:20:34 +0800 Subject: [PATCH 04/14] support Oracle DB --- examples/lightrag_oracle_demo.py | 4 ++-- lightrag/lightrag.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 93f0799d..166f2b9a 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -71,13 +71,13 @@ async def main(): # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud oracle_db = OracleDB(config={ - "user":"RAG", + "user":"username", 
"password":"xxxxxxxxx", "dsn":"xxxxxxx_medium", "config_dir":"dir/path/to/oracle/config", "wallet_location":"dir/path/to/oracle/wallet", "wallet_password":"xxxxxxxxx", - "workspace":"company" # specify which docs we want to store and query + "workspace":"company" # specify which docs you want to store and query } ) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 820051b7..3a426a22 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -146,7 +146,7 @@ class LightRAG: self.llm_response_cache = ( self.key_string_value_json_storage_cls( - namespace="llm_response_cache", global_config=asdict(self) + namespace="llm_response_cache", global_config=asdict(self),embedding_func=None ) if self.enable_llm_cache else None From 07c02385c28d6d13655dd98b4e9bfd356f6fe6f4 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:27:15 +0800 Subject: [PATCH 05/14] support Oralce DB --- requirements.txt | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 890659cf..d57de091 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,22 @@ accelerate -aioboto3 aiohttp -graspologic -hnswlib -nano-vectordb -neo4j -networkx -ollama -openai pyvis tenacity +xxhash +# lmdeploy[all] + +# LLM packages tiktoken torch transformers -xxhash -# lmdeploy[all] +aioboto3 +ollama +openai + +# database packages +graspologic +hnswlib +networkx oracledb +nano-vectordb +neo4j \ No newline at end of file From a8ec12efe0c92c4f187f808e1ac901501193f614 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:19:42 +0800 Subject: [PATCH 06/14] fix bug --- .gitignore | 1 + lightrag/base.py | 2 +- lightrag/operate.py | 10 +++++++--- lightrag/prompt.py | 2 +- test.py | 2 +- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index def738b2..65aaaa02 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ local_neo4jWorkDir/ neo4jWorkDir/ ignore_this.txt .venv/ +*.ignore.* diff --git a/lightrag/base.py b/lightrag/base.py index 97524472..b88acae2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -84,7 +84,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc + embedding_func: EmbeddingFunc = None async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/lightrag/operate.py b/lightrag/operate.py index 04725d6a..3fcc80c8 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -16,6 +16,7 @@ from .utils import ( split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, + locate_json_string_body_from_string ) from .base import ( BaseGraphStorage, @@ -403,9 +404,10 @@ async def local_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) + json_text = locate_json_string_body_from_string(result) try: - keywords_data = json.loads(result) + keywords_data = json.loads(json_text) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: @@ -670,9 +672,10 @@ async def global_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) + json_text = locate_json_string_body_from_string(result) try: - keywords_data = json.loads(result) + 
keywords_data = json.loads(json_text) keywords = keywords_data.get("high_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: @@ -911,8 +914,9 @@ async def hybrid_query( kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) + json_text = locate_json_string_body_from_string(result) try: - keywords_data = json.loads(result) + keywords_data = json.loads(json_text) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) hl_keywords = ", ".join(hl_keywords) diff --git a/lightrag/prompt.py b/lightrag/prompt.py index e0713859..5de116b3 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -14,7 +14,7 @@ Given a text document that is potentially relevant to this activity and a list o -Steps- 1. Identify all entities. For each identified entity, extract the following information: -- entity_name: Name of the entity, capitalized +- entity_name: Name of the entity, use same language as input text. If English, capitalized the name. - entity_type: One of the following types: [{entity_types}] - entity_description: Comprehensive description of the entity's attributes and activities Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter} diff --git a/test.py b/test.py index 35c03afe..c5d7fec0 100644 --- a/test.py +++ b/test.py @@ -18,7 +18,7 @@ rag = LightRAG( # llm_model_func=gpt_4o_complete # Optionally, use a stronger model ) -with open("./book.txt") as f: +with open("./dickens/book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search From f905305eed5292c6758f5423b5254347ffe12814 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:24:31 +0800 Subject: [PATCH 07/14] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index ab6312a7..2980de6f 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -279,7 +279,7 @@ class OracleVectorDBStorage(BaseVectorStorage): pass - #################### query method ################ + #################### query method ############### async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: """从向量数据库中查询数据""" embeddings = await self.embedding_func([query]) From ff0c763dbdbf7e78ccce2ba1e740abd462277ce6 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:32:30 +0800 Subject: [PATCH 08/14] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 2980de6f..1d8b5002 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -413,7 +413,7 @@ class OracleGraphStorage(BaseGraphStorage): """写入graphhml图文件""" logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!") - #################### query method ################ + #################### query method ################# async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id) From c89bf79ee02dac1f29eafd49cdc6a5b87cc27051 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:37:24 +0800 Subject: [PATCH 09/14] Create lightrag_api_oracle_demo..py --- 
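A note on the keyword-extraction fix in PATCH 06 above: LLM replies are now passed through `locate_json_string_body_from_string` before `json.loads`, so answers that wrap the JSON object in extra prose no longer raise `JSONDecodeError`. The helper is imported from `lightrag/utils.py` and its body is not shown in this series; a minimal equivalent (an assumption, not necessarily the exact library code) looks like:

import re
from typing import Union

def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    """Return the first {...} body found in `content`, or None if there is none."""
    # DOTALL lets the JSON body span multiple lines of surrounding prose.
    match = re.search(r"\{.*\}", content, re.DOTALL)
    return match.group(0) if match else None
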
examples/lightrag_api_oracle_demo..py | 232 ++++++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 examples/lightrag_api_oracle_demo..py diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py new file mode 100644 index 00000000..8346b66c --- /dev/null +++ b/examples/lightrag_api_oracle_demo..py @@ -0,0 +1,232 @@ + +from fastapi import FastAPI, HTTPException, File, UploadFile +from contextlib import asynccontextmanager +from pydantic import BaseModel +from typing import Optional + +import sys, os +print(os.getcwd()) +from pathlib import Path +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) + +import asyncio +import nest_asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np +from datetime import datetime + +from lightrag.kg.oracle_impl import OracleDB + + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +DEFAULT_RAG_DIR = "index_default" + + +# We use OpenAI compatible API to call LLM on Oracle Cloud +# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway +BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" +APIKEY = "ocigenerativeai" + +# Configure working directory +WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") +print(f"WORKING_DIR: {WORKING_DIR}") +LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus") +print(f"LLM_MODEL: {LLM_MODEL}") +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0") +print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") +EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512)) +print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") + + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + LLM_MODEL, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=APIKEY, + base_url=BASE_URL, + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model=EMBEDDING_MODEL, + api_key=APIKEY, + base_url=BASE_URL, + ) + + +async def get_embedding_dim(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + return embedding_dim + +async def init(): + + # Detect embedding dimension + embedding_dimension = await get_embedding_dim() + print(f"Detected embedding dimension: {embedding_dimension}") + # Create Oracle DB connection + # The `config` parameter is the connection configuration of Oracle DB + # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html + # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query + # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud + + + oracle_db = OracleDB(config={ + "user":"", + "password":"", + "dsn":"", + "config_dir":"", + "wallet_location":"", + "wallet_password":"", + "workspace":"" + } # specify which docs you want to store and query + ) + + # Check if Oracle DB tables exist, if not, tables will be created + await oracle_db.check_tables() + # Initialize LightRAG + # We use Oracle DB as the KV/vector/graph storage + rag = 
LightRAG(
+        enable_llm_cache=False,
+        working_dir=WORKING_DIR,
+        chunk_token_size=512,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=embedding_dimension,
+            max_token_size=512,
+            func=embedding_func,
+        ),
+        graph_storage = "OracleGraphStorage",
+        kv_storage="OracleKVStorage",
+        vector_storage="OracleVectorDBStorage"
+    )
+
+    # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
+    rag.graph_storage_cls.db = oracle_db
+    rag.key_string_value_json_storage_cls.db = oracle_db
+    rag.vector_db_storage_cls.db = oracle_db
+
+    return rag
+
+# Data models
+
+
+class QueryRequest(BaseModel):
+    query: str
+    mode: str = "hybrid"
+    only_need_context: bool = False
+
+
+class InsertRequest(BaseModel):
+    text: str
+
+
+class Response(BaseModel):
+    status: str
+    data: Optional[str] = None
+    message: Optional[str] = None
+
+
+# API routes
+
+rag = None  # the global RAG instance, initialized at startup
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global rag
+    rag = await init()  # initialize `rag` when the application starts
+    print("done!")
+    yield
+
+
+app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
+
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+    try:
+        # loop = asyncio.get_event_loop()
+        result = await rag.aquery(
+            request.query,
+            param=QueryParam(
+                mode=request.mode, only_need_context=request.only_need_context
+            ),
+        )
+        return Response(status="success", data=result)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+    try:
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(request.text))
+        return Response(status="success", message="Text inserted successfully")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert_file", response_model=Response)
+async def insert_file(file: UploadFile = File(...)):
+    try:
+        file_content = await file.read()
+        # Read file content
+        try:
+            content = file_content.decode("utf-8")
+        except UnicodeDecodeError:
+            # If UTF-8 decoding fails, try other encodings
+            content = file_content.decode("gbk")
+        # Insert file content
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(content))
+
+        return Response(
+            status="success",
+            message=f"File content from {file.filename} inserted successfully",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_oracle_demo..py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file (multipart upload, since the endpoint expects an UploadFile):
+# curl -X POST "http://127.0.0.1:8020/insert_file" -F "file=@path/to/your/file.txt"
+
+# 4. 
Health check: +# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file From 45156ef0e6b224a72d48e6f69f7d1aae87e6bfb6 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 12 Nov 2024 11:19:34 +0800 Subject: [PATCH 10/14] add oracle support --- examples/lightrag_api_oracle_demo..py | 11 ++++++++--- examples/lightrag_oracle_demo.py | 10 +++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index 8346b66c..d63b4588 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -5,10 +5,7 @@ from pydantic import BaseModel from typing import Optional import sys, os -print(os.getcwd()) from pathlib import Path -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) import asyncio import nest_asyncio @@ -21,6 +18,14 @@ from datetime import datetime from lightrag.kg.oracle_impl import OracleDB +print(os.getcwd()) + +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) + + + + # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 166f2b9a..3e196400 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -1,20 +1,16 @@ - - import sys, os -print(os.getcwd()) from pathlib import Path -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) - import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np from datetime import datetime - from lightrag.kg.oracle_impl import OracleDB +print(os.getcwd()) +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) WORKING_DIR = "./dickens" From 5dce625464a045c75c1c6b5bdf162808c246645c Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 12 Nov 2024 12:02:24 +0800 Subject: [PATCH 11/14] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 73 -------------------------------------- 1 file changed, 73 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 7340ad7d..2de4526c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -502,79 +502,6 @@ class OracleGraphStorage(BaseGraphStorage): else: #print("Node Edge not exist!",self.db.workspace, source_node_id) return [] - - #################### INSERT method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - """插入或更新节点""" - #print("go into upsert node method") - entity_name = node_id - entity_type = node_data["entity_type"] - description = node_data["description"] - source_id = node_data["source_id"] - content = entity_name+description - contents = [content] - batches = [ - contents[i: i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"].format( - workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id - ) - #print(merge_sql) - await self.db.execute(merge_sql, 
[self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector]) - #self._graph.add_node(node_id, **node_data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): - """Insert or update an edge.""" - #print("go into upsert edge method") - source_name = source_node_id - target_name = target_node_id - weight = edge_data["weight"] - keywords = edge_data["keywords"] - description = edge_data["description"] - source_chunk_id = edge_data["source_id"] - content = keywords+source_name+target_name+description - contents = [content] - batches = [ - contents[i: i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"].format( - workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id - ) - #print(merge_sql) - await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector]) - #self._graph.add_edge(source_node_id, target_node_id, **edge_data) - - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - """Generate embeddings for nodes.""" - if algorithm not in self._node_embed_algorithms: - raise ValueError(f"Node embedding algorithm {algorithm} not supported") - return await self._node_embed_algorithms[algorithm]() - - async def _node2vec_embed(self): - """Generate embeddings for nodes.""" - from graspologic import embed - - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] - return embeddings, nodes_ids N_T = { From 66799fb6857adbd6f2a3eb6d1ca64eaa9b7a1a19 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:03:03 +0800 Subject: [PATCH 12/14] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 2de4526c..e40f60c8 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -306,28 +306,6 @@ class OracleVectorDBStorage(BaseVectorStorage): @dataclass class OracleGraphStorage(BaseGraphStorage): """Oracle-based graph storage module""" - # @staticmethod - # def load_graph(file_name) -> nx.Graph: - # """Read a graphml graph file""" - - # @staticmethod - # def write_graph(graph: nx.Graph, file_name): - # # """Write a graphml graph file""" - - # @staticmethod - # def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: - # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - # Return the largest connected component of the graph, with nodes and edges sorted in a stable way. - # That is, the same input graph always yields the same output LCC. - # """ - - - # @staticmethod - # def _stabilize_graph(graph: nx.Graph) -> nx.Graph: - # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - # Ensure an undirected graph with the same relationships will always be read the same way. 
- # (That is, an undirected graph with the same relationships is always read in the same way.) - # """ def __post_init__(self): """Load graph from graphml file""" From 90790747ee4cd650d11327bacba54b6dec373e7c Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:23:03 +0800 Subject: [PATCH 13/14] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index e40f60c8..e0671a71 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -680,4 +680,4 @@ SQL_TEMPLATES = { WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """ - } + } \ No newline at end of file From 33caba3e127ce4827c853ea86e4dfdfa98f92c25 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:32:40 +0800 Subject: [PATCH 14/14] fix pre commit --- examples/lightrag_api_oracle_demo..py | 79 ++--- examples/lightrag_oracle_demo.py | 49 +-- lightrag/base.py | 2 + lightrag/kg/oracle_impl.py | 447 +++++++++++++------------- lightrag/lightrag.py | 52 +-- lightrag/operate.py | 2 +- requirements.txt | 24 +- 7 files changed, 345 insertions(+), 310 deletions(-) diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index d63b4588..3bfae452 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -1,10 +1,10 @@ - from fastapi import FastAPI, HTTPException, File, UploadFile from contextlib import asynccontextmanager from pydantic import BaseModel from typing import Optional -import sys, os +import sys +import os from pathlib import Path import asyncio @@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np -from datetime import datetime from lightrag.kg.oracle_impl import OracleDB @@ -24,8 +23,6 @@ script_directory = Path(__file__).resolve().parent.parent sys.path.append(os.path.abspath(script_directory)) - - # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() @@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -80,8 +78,8 @@ async def get_embedding_dim(): embedding_dim = embedding.shape[1] return embedding_dim + async def init(): - # Detect embedding dimension embedding_dimension = await get_embedding_dim() print(f"Detected embedding dimension: {embedding_dimension}") @@ -91,36 +89,36 @@ async def init(): # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud + oracle_db = OracleDB( + config={ + "user": "", + "password": "", + "dsn": "", + "config_dir": "", + "wallet_location": "", + "wallet_password": "", + "workspace": "", + } # specify which docs you want to store and query + ) - oracle_db = OracleDB(config={ - "user":"", - "password":"", - "dsn":"", - "config_dir":"", - "wallet_location":"", - "wallet_password":"", - "workspace":"" - } # specify which docs you want to store and query - ) - # Check if Oracle DB tables exist, if not, tables will be created await oracle_db.check_tables() # Initialize LightRAG - 
# We use Oracle DB as the KV/vector/graph storage + # We use Oracle DB as the KV/vector/graph storage rag = LightRAG( - enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=512, - func=embedding_func, - ), - graph_storage = "OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage" - ) + enable_llm_cache=False, + working_dir=WORKING_DIR, + chunk_token_size=512, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=512, + func=embedding_func, + ), + graph_storage="OracleGraphStorage", + kv_storage="OracleKVStorage", + vector_storage="OracleVectorDBStorage", + ) # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool rag.graph_storage_cls.db = oracle_db @@ -129,6 +127,7 @@ async def init(): return rag + # Data models @@ -152,6 +151,7 @@ class Response(BaseModel): rag = None # defined as a global object + @asynccontextmanager async def lifespan(app: FastAPI): global rag @@ -160,18 +160,21 @@ async def lifespan(app: FastAPI): yield -app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan) +app = FastAPI( + title="LightRAG API", description="API for RAG operations", lifespan=lifespan +) + @app.post("/query", response_model=Response) async def query_endpoint(request: QueryRequest): try: # loop = asyncio.get_event_loop() result = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, only_need_context=request.only_need_context - ), - ) + request.query, + param=QueryParam( + mode=request.mode, only_need_context=request.only_need_context + ), + ) return Response(status="success", data=result) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -234,4 +237,4 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert_file" -F "file=@path/to/your/file.txt" # 4. Health check: -# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file +# curl -X GET "http://127.0.0.1:8020/health"
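The `workspace` setting referenced in these demos is what isolates one document collection from another inside the shared tables: every row the storage classes write is tagged with it, and every SQL template filters on it. A sketch of the idea (the workspace names are hypothetical, and `base_config` stands in for the connection settings shown above):

from lightrag.kg.oracle_impl import OracleDB

# Hypothetical connection settings; reuse the user/password/dsn/wallet values
# from the demo configuration above.
base_config = {"user": "RAG", "password": "xxxxxxxxx", "dsn": "xxxxxxx_medium"}

# Two corpora can share one schema; each OracleDB instance reads and writes
# only rows whose workspace column matches its own setting.
company_db = OracleDB(config={**base_config, "workspace": "company"})
research_db = OracleDB(config={**base_config, "workspace": "research"})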
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 3e196400..365b6225 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -1,11 +1,11 @@ -import sys, os +import sys +import os from pathlib import Path import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np -from datetime import datetime from lightrag.kg.oracle_impl import OracleDB print(os.getcwd()) @@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -66,22 +67,21 @@ async def main(): # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud - oracle_db = OracleDB(config={ - "user":"username", - "password":"xxxxxxxxx", - "dsn":"xxxxxxx_medium", - "config_dir":"dir/path/to/oracle/config", - "wallet_location":"dir/path/to/oracle/wallet", - "wallet_password":"xxxxxxxxx", - "workspace":"company" # specify which docs you want to store and query + oracle_db = OracleDB( + config={ + "user": "username", + "password": "xxxxxxxxx", + "dsn": "xxxxxxx_medium", + "config_dir": "dir/path/to/oracle/config", + "wallet_location": "dir/path/to/oracle/wallet", + "wallet_password": "xxxxxxxxx", + "workspace": "company", # specify which docs you want to store and query } - ) - + ) # Check if Oracle DB tables exist, if not, tables will be created await oracle_db.check_tables() - # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage rag = LightRAG( @@ -93,10 +93,10 @@ async def main(): embedding_dim=embedding_dimension, max_token_size=512, func=embedding_func, - ), - graph_storage = "OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage" + ), + graph_storage="OracleGraphStorage", + kv_storage="OracleKVStorage", + vector_storage="OracleVectorDBStorage", ) # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool @@ -106,18 +106,23 @@ async def main(): # Extract and Insert into LightRAG storage with open("./dickens/demo.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) + await rag.ainsert(f.read()) # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] for mode in modes: - print("="*20, mode, "="*20) - print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode))) - print("-"*100, "\n") + print("=" * 20, mode, "=" * 20) + print( + await rag.aquery( + "What are the top themes in this story?", + param=QueryParam(mode=mode), + ) + ) + print("-" * 100, "\n") except Exception as e: print(f"An error occurred: {e}") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/lightrag/base.py b/lightrag/base.py index 379efeb3..46dfc800 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace): @dataclass class BaseKVStorage(Generic[T], 
StorageNameSpace): embedding_func: EmbeddingFunc + async def all_keys(self) -> list[str]: raise NotImplementedError @@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): embedding_func: EmbeddingFunc = None + async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index e0671a71..96a9e795 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -1,9 +1,9 @@ import asyncio -#import html -#import os + +# import html +# import os from dataclasses import dataclass -from typing import Any, Union, cast -import networkx as nx +from typing import Union import numpy as np import array @@ -16,8 +16,9 @@ from ..base import ( import oracledb + class OracleDB: - def __init__(self,config,**kwargs): + def __init__(self, config, **kwargs): self.host = config.get("host", None) self.port = config.get("port", None) self.user = config.get("user", None) @@ -32,21 +33,21 @@ class OracleDB: logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") if self.user is None or self.password is None: raise ValueError("Missing database user or password in addon_params") - + try: oracledb.defaults.fetch_lobs = False self.pool = oracledb.create_pool_async( - user = self.user, - password = self.password, - dsn = self.dsn, - config_dir = self.config_dir, - wallet_location = self.wallet_location, - wallet_password = self.wallet_password, - min = 1, - max = self.max, - increment = self.increment - ) + user=self.user, + password=self.password, + dsn=self.dsn, + config_dir=self.config_dir, + wallet_location=self.wallet_location, + wallet_password=self.wallet_password, + min=1, + max=self.max, + increment=self.increment, + ) logger.info(f"Connected to Oracle database at {self.dsn}") except Exception as e: logger.error(f"Failed to connect to Oracle database at {self.dsn}") @@ -90,12 +91,14 @@ class OracleDB: arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - + async def check_tables(self): - for k,v in TABLES.items(): + for k, v in TABLES.items(): try: if k.lower() == "lightrag_graph": - await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only") + await self.query( + "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only" + ) else: await self.query("SELECT 1 FROM {k}".format(k=k)) except Exception as e: @@ -108,12 +111,11 @@ class OracleDB: except Exception as e: logger.error(f"Failed to create table {k} in Oracle database") logger.error(f"Oracle database error: {e}") - - logger.info(f"Finished check all tables in Oracle database") - - - async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]: - async with self.pool.acquire() as connection: + + logger.info("Finished checking all tables in Oracle database") + + async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]: + async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler with connection.cursor() as cursor: @@ -136,9 +138,9 @@ class OracleDB: data = dict(zip(columns, row)) else: data = None - return data + return data - async def execute(self,sql: str, data: list = None): + async def execute(self, sql: str, data: list = None): # logger.info("go into OracleDB execute method") try: async with self.pool.acquire() as connection: @@ -148,58 
@@ class OracleDB: if data is None: await cursor.execute(sql) else: - #print(data) - #print(sql) - await cursor.execute(sql,data) + # print(data) + # print(sql) + await cursor.execute(sql, data) await connection.commit() except Exception as e: - logger.error(f"Oracle database error: {e}") + logger.error(f"Oracle database error: {e}") print(sql) print(data) raise + @dataclass class OracleKVStorage(BaseKVStorage): - # should pass db object to self.db def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] - + self._max_batch_size = self.global_config["embedding_batch_num"] + ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> Union[dict, None]: """Fetch doc_full data by id.""" - SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id) - #print("get_by_id:"+SQL) - res = await self.db.query(SQL) + SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( + workspace=self.db.workspace, id=id + ) + # print("get_by_id:"+SQL) + res = await self.db.query(SQL) if res: - data = res #{"data":res} - #print (data) + data = res # {"data":res} + # print (data) return data else: return None # Query by id - async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]: + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """Fetch doc_chunks data by ids.""" - SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace, - ids=",".join([f"'{id}'" for id in ids])) - #print("get_by_ids:"+SQL) - res = await self.db.query(SQL,multirows=True) + SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids]) + ) + # print("get_by_ids:"+SQL) + res = await self.db.query(SQL, multirows=True) if res: - data = res # [{"data":i} for i in res] - #print(data) + data = res # [{"data":i} for i in res] + # print(data) return data else: return None - + async def filter_keys(self, keys: list[str]) -> set[str]: """Filter out keys whose content already exists.""" - SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], - workspace=self.db.workspace, - ids=",".join([f"'{k}'" for k in keys])) - res = await self.db.query(SQL,multirows=True) + SQL = SQL_TEMPLATES["filter_keys"].format( + table_name=N_T[self.namespace], + workspace=self.db.workspace, + ids=",".join([f"'{k}'" for k in keys]), + ) + res = await self.db.query(SQL, multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -208,14 +215,13 @@ class OracleKVStorage(BaseKVStorage): exist_keys = [] data = set([s for s in keys if s not in exist_keys]) return data - - + ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) - #print(self._data) - #values = [] + # print(self._data) + # values = [] if self.namespace == "text_chunks": list_data = [ { @@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage): ] contents = [v["content"] for v in data.values()] batches = [ - contents[i: i + self._max_batch_size] + contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embeddings_list = await asyncio.gather( @@ -235,42 +241,45 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - #print(list_data) + # print(list_data) for item in list_data: - merge_sql = 
SQL_TEMPLATES["merge_chunk"].format( - check_id=item["__id__"] - ) + merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) - values = [item["__id__"], item["content"], self.db.workspace, item["tokens"], - item["chunk_order_index"], item["full_doc_id"], item["__vector__"]] - #print(merge_sql) + values = [ + item["__id__"], + item["content"], + self.db.workspace, + item["tokens"], + item["chunk_order_index"], + item["full_doc_id"], + item["__vector__"], + ] + # print(merge_sql) await self.db.execute(merge_sql, values) if self.namespace == "full_docs": for k, v in self._data.items(): - #values.clear() + # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"].format( check_id=k, ) values = [k, self._data[k]["content"], self.db.workspace] - #print(merge_sql) + # print(merge_sql) await self.db.execute(merge_sql, values) return left_data - async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: logger.info("full doc and chunk data had been saved into oracle db!") - @dataclass class OracleVectorDBStorage(BaseVectorStorage): cosine_better_than_threshold: float = 0.2 def __post_init__(self): pass - + async def upsert(self, data: dict[str, dict]): """向向量数据库中插入数据""" pass @@ -278,53 +287,51 @@ class OracleVectorDBStorage(BaseVectorStorage): async def index_done_callback(self): pass - #################### query method ############### async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + """从向量数据库中查询数据""" embeddings = await self.embedding_func([query]) embedding = embeddings[0] # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = ', '.join(map(str, embedding.tolist())) + embedding_string = ", ".join(map(str, embedding.tolist())) SQL = SQL_TEMPLATES[self.namespace].format( - embedding_string=embedding_string, - dimension=dimension, - dtype=dtype, - workspace=self.db.workspace, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) + embedding_string=embedding_string, + dimension=dimension, + dtype=dtype, + workspace=self.db.workspace, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) # print(SQL) results = await self.db.query(SQL, multirows=True) - #print("vector search result:",results) + # print("vector search result:",results) return results @dataclass -class OracleGraphStorage(BaseGraphStorage): +class OracleGraphStorage(BaseGraphStorage): """基于Oracle的图存储模块""" - + def __post_init__(self): """从graphml文件加载图""" self._max_batch_size = self.global_config["embedding_batch_num"] - #################### insert method ################ - + async def upsert_node(self, node_id: str, node_data: dict[str, str]): """插入或更新节点""" - #print("go into upsert node method") + # print("go into upsert node method") entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] - source_id = node_data["source_id"] - content = entity_name+description + source_id = node_data["source_id"] + content = entity_name + description contents = [content] batches = [ - contents[i: i + self._max_batch_size] + contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embeddings_list = await asyncio.gather( @@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage): embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] merge_sql = SQL_TEMPLATES["merge_node"].format( - workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id + 
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id ) - #print(merge_sql) - await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector]) - #self._graph.add_node(node_id, **node_data) + # print(merge_sql) + await self.db.execute( + merge_sql, + [ + self.db.workspace, + entity_name, + entity_type, + description, + source_id, + content, + content_vector, + ], + ) + # self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): """Insert or update an edge.""" - #print("go into upsert edge method") + # print("go into upsert edge method") source_name = source_node_id target_name = target_node_id weight = edge_data["weight"] keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] - content = keywords+source_name+target_name+description + content = keywords + source_name + target_name + description contents = [content] batches = [ - contents[i: i + self._max_batch_size] + contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embeddings_list = await asyncio.gather( @@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage): embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] merge_sql = SQL_TEMPLATES["merge_edge"].format( - workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id + workspace=self.db.workspace, + source_name=source_name, + target_name=target_name, + source_chunk_id=source_chunk_id, ) - #print(merge_sql) - await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector]) - #self._graph.add_edge(source_node_id, target_node_id, **edge_data) + # print(merge_sql) + await self.db.execute( + merge_sql, + [ + self.db.workspace, + source_name, + target_name, + weight, + keywords, + description, + source_chunk_id, + content, + content_vector, + ], + ) + # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: """Generate embeddings for nodes.""" if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - async def index_done_callback(self): """Write graphml graph file""" - logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!") - + logger.info( + "Node and edge data had been saved into oracle db already, so nothing to do here!" 
+ ) + #################### query method ################# async def has_node(self, node_id: str) -> bool: - """Check whether a node exists by node id.""" - SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id) - # print(SQL) - #print(self.db.workspace, node_id) + """Check whether a node exists by node id.""" + SQL = SQL_TEMPLATES["has_node"].format( + workspace=self.db.workspace, node_id=node_id + ) + # print(SQL) + # print(self.db.workspace, node_id) res = await self.db.query(SQL) if res: - #print("Node exist!",res) + # print("Node exist!",res) return True else: - #print("Node not exist!") + # print("Node not exist!") return False async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """Check whether an edge exists by source and target node ids.""" - SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id) + SQL = SQL_TEMPLATES["has_edge"].format( + workspace=self.db.workspace, + source_node_id=source_node_id, + target_node_id=target_node_id, + ) # print(SQL) res = await self.db.query(SQL) if res: - #print("Edge exist!",res) + # print("Edge exist!",res) return True else: - #print("Edge not exist!") + # print("Edge not exist!") return False async def node_degree(self, node_id: str) -> int: """Get a node's degree by node id.""" - SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id) + SQL = SQL_TEMPLATES["node_degree"].format( + workspace=self.db.workspace, node_id=node_id + ) # print(SQL) res = await self.db.query(SQL) if res: - #print("Node degree",res["degree"]) + # print("Node degree",res["degree"]) return res["degree"] else: - #print("Edge not exist!") + # print("Edge not exist!") return 0 - async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get an edge's degree by source and target node ids.""" degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) - #print("Edge degree",degree) + # print("Edge degree",degree) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: """Get node data by node id.""" - SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id) + SQL = SQL_TEMPLATES["get_node"].format( + workspace=self.db.workspace, node_id=node_id + ) # print(self.db.workspace, node_id) # print(SQL) res = await self.db.query(SQL) if res: - #print("Get node!",self.db.workspace, node_id,res) + # print("Get node!",self.db.workspace, node_id,res) return res else: - #print("Can't get node!",self.db.workspace, node_id) + # print("Can't get node!",self.db.workspace, node_id) return None - + async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """Get an edge by source and target node ids.""" - SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id) + SQL = SQL_TEMPLATES["get_edge"].format( + workspace=self.db.workspace, + source_node_id=source_node_id, + target_node_id=target_node_id, + ) res = await self.db.query(SQL) if res: - #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) + # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res else: - #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) + # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) return None async def get_node_edges(self, source_node_id: str): """Get all edges of a node by node id.""" if await self.has_node(source_node_id): - SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace, - source_node_id=source_node_id) + SQL = 
SQL_TEMPLATES["get_node_edges"].format( + workspace=self.db.workspace, source_node_id=source_node_id + ) res = await self.db.query(sql=SQL, multirows=True) if res: - data = [(i["source_name"],i["target_name"]) for i in res] - #print("Get node edge!",self.db.workspace, source_node_id,data) + data = [(i["source_name"], i["target_name"]) for i in res] + # print("Get node edge!",self.db.workspace, source_node_id,data) return data else: - #print("Node Edge not exist!",self.db.workspace, source_node_id) + # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] @@ -487,12 +531,12 @@ N_T = { "text_chunks": "LIGHTRAG_DOC_CHUNKS", "chunks": "LIGHTRAG_DOC_CHUNKS", "entities": "LIGHTRAG_GRAPH_NODES", - "relationships": "LIGHTRAG_GRAPH_EDGES" + "relationships": "LIGHTRAG_GRAPH_EDGES", } TABLES = { - "LIGHTRAG_DOC_FULL": - {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL ( + "LIGHTRAG_DOC_FULL": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( id varchar(256)PRIMARY KEY, workspace varchar(1024), doc_name varchar(1024), @@ -500,61 +544,63 @@ TABLES = { meta JSON, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL - )"""}, - - "LIGHTRAG_DOC_CHUNKS": - {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS ( + )""" + }, + "LIGHTRAG_DOC_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( id varchar(256) PRIMARY KEY, workspace varchar(1024), full_doc_id varchar(256), chunk_order_index NUMBER, - tokens NUMBER, + tokens NUMBER, content CLOB, content_vector VECTOR, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL - )"""}, - - "LIGHTRAG_GRAPH_NODES": - {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES ( + updatetime TIMESTAMP DEFAULT NULL + )""" + }, + "LIGHTRAG_GRAPH_NODES": { + "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES ( id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, workspace varchar(1024), name varchar(2048), - entity_type varchar(1024), + entity_type varchar(1024), description CLOB, source_chunk_id varchar(256), content CLOB, content_vector VECTOR, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL - )"""}, - "LIGHTRAG_GRAPH_EDGES": - {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES ( + )""" + }, + "LIGHTRAG_GRAPH_EDGES": { + "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES ( id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, workspace varchar(1024), source_name varchar(2048), - target_name varchar(2048), + target_name varchar(2048), weight NUMBER, - keywords CLOB, + keywords CLOB, description CLOB, source_chunk_id varchar(256), content CLOB, content_vector VECTOR, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL - )"""}, - "LIGHTRAG_LLM_CACHE": - {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE ( + )""" + }, + "LIGHTRAG_LLM_CACHE": { + "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( id varchar(256) PRIMARY KEY, send clob, return clob, model varchar(1024), createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL - )"""}, - - "LIGHTRAG_GRAPH": - {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph + )""" + }, + "LIGHTRAG_GRAPH": { + "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph VERTEX TABLES ( lightrag_graph_nodes KEY (id) LABEL entity @@ -565,93 +611,67 @@ TABLES = { SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name) DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name) LABEL has_relation - PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) - ) OPTIONS(ALLOW MIXED 
PROPERTY TYPES)"""}, - } + PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) + ) OPTIONS(ALLOW MIXED PROPERTY TYPES)""" + }, +} SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": - "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", - - "get_by_id_text_chunks": - "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", - - "get_by_ids_full_docs": - "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", - - "get_by_ids_text_chunks": - "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", - - "filter_keys": - "select id from {table_name} where workspace='{workspace}' and id in ({ids})", - + "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})", "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = '{check_id}') WHEN NOT MATCHED THEN INSERT(id,content,workspace) values(:1,:2,:3) """, - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a USING DUAL ON (a.id = '{check_id}') WHEN NOT MATCHED THEN INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) values (:1,:2,:3,:4,:5,:6,:7) """, - # SQL for VectorStorage - "entities": - """SELECT name as entity_name FROM - (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') + "entities": """SELECT name as entity_name FROM + (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - - "relationships": - """SELECT source_name as src_id, target_name as tgt_id FROM - (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') + "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM + (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - - "chunks": - """SELECT id FROM - (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') + "chunks": """SELECT id FROM + (SELECT 
id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", - # SQL for GraphStorage - "has_node": - """SELECT * FROM GRAPH_TABLE (lightrag_graph + "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) WHERE a.workspace='{workspace}' AND a.name='{node_id}' COLUMNS (a.name))""", - - "has_edge": - """SELECT * FROM GRAPH_TABLE (lightrag_graph + "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) -[e]-> (b) WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{source_node_id}' AND b.name='{target_node_id}' COLUMNS (e.source_name,e.target_name) )""", - - "node_degree": - """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph + "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{node_id}' or b.name = '{node_id}' COLUMNS (a.name))""", - - "get_node": - """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description + "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description FROM GRAPH_TABLE (lightrag_graph - MATCH (a) + MATCH (a) WHERE a.workspace='{workspace}' AND a.name='{node_id}' COLUMNS (a.name) ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name WHERE t2.workspace='{workspace}'""", - - "get_edge": - """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, + "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) @@ -659,15 +679,12 @@ SQL_TEMPLATES = { AND a.name='{source_node_id}' and b.name = '{target_node_id}' COLUMNS (e.id,a.name as source_id) ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", - - "get_node_edges": - """SELECT source_name,target_name + "get_node_edges": """SELECT source_name,target_name FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' AND a.name='{source_node_id}' COLUMNS (a.name as source_name,b.name as target_name))""", - "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') @@ -679,5 +696,5 @@ SQL_TEMPLATES = { ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """ - } \ No newline at end of file + values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 52786970..50e33405 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -38,15 +38,11 @@ from .storage import ( JsonKVStorage, NanoVectorDBStorage, NetworkXStorage, - ) +) from .kg.neo4j_impl import Neo4JStorage -from .kg.oracle_impl import ( - OracleKVStorage, - OracleGraphStorage, - OracleVectorDBStorage - ) +from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage # future KG 
integrations @@ -54,6 +50,7 @@ from .kg.oracle_impl import ( # GraphStorage as ArangoDBStorage # ) + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: return asyncio.get_event_loop() @@ -72,7 +69,7 @@ class LightRAG: default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) - kv_storage : str = field(default="JsonKVStorage") + kv_storage: str = field(default="JsonKVStorage") vector_storage: str = field(default="NanoVectorDBStorage") graph_storage: str = field(default="NetworkXStorage") @@ -115,7 +112,7 @@ class LightRAG: # storage vector_db_storage_cls_kwargs: dict = field(default_factory=dict) - + enable_llm_cache: bool = True # extension @@ -134,18 +131,25 @@ class LightRAG: # @TODO: should move all storage setup here to leverage initial start params attached to self. - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage] - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage] + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class()[self.kv_storage] + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ + self.vector_storage + ] + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ + self.graph_storage + ] if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = ( self.key_string_value_json_storage_cls( - namespace="llm_response_cache", global_config=asdict(self),embedding_func=None + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, ) if self.enable_llm_cache else None @@ -159,13 +163,19 @@ class LightRAG: # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( - namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func + namespace="full_docs", + global_config=asdict(self), + embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( - namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func + namespace="text_chunks", + global_config=asdict(self), + embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func + namespace="chunk_entity_relation", + global_config=asdict(self), + embedding_func=self.embedding_func, ) #### # add embedding func by walter over @@ -200,13 +210,11 @@ class LightRAG: def _get_storage_class(self) -> Type[BaseGraphStorage]: return { # kv storage - "JsonKVStorage":JsonKVStorage, - "OracleKVStorage":OracleKVStorage, - + "JsonKVStorage": JsonKVStorage, + "OracleKVStorage": OracleKVStorage, # vector storage - "NanoVectorDBStorage":NanoVectorDBStorage, - "OracleVectorDBStorage":OracleVectorDBStorage, - + "NanoVectorDBStorage": NanoVectorDBStorage, + "OracleVectorDBStorage": OracleVectorDBStorage, # graph storage "NetworkXStorage": NetworkXStorage, "Neo4JStorage": Neo4JStorage, diff --git a/lightrag/operate.py b/lightrag/operate.py index 5288c26d..db7c9401 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -16,7 +16,7 @@ from .utils import ( split_string_by_multi_markers, truncate_list_by_token_size, 
process_combine_contexts, - locate_json_string_body_from_string + locate_json_string_body_from_string, ) from .base import ( BaseGraphStorage, diff --git a/requirements.txt b/requirements.txt index d57de091..6adb6929 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,22 @@ accelerate +aioboto3 aiohttp + +# database packages +graspologic +hnswlib +nano-vectordb +neo4j +networkx +ollama +openai +oracledb pyvis tenacity -xxhash # lmdeploy[all] # LLM packages tiktoken torch transformers -aioboto3 -ollama -openai - -# database packages -graspologic -hnswlib -networkx -oracledb -nano-vectordb -neo4j \ No newline at end of file +xxhash
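With `oracledb` added to the requirements, a quick smoke test of the new backend can mirror the demo scripts. A sketch (the connection values are placeholders to be replaced with real credentials):

import asyncio
from lightrag.kg.oracle_impl import OracleDB

async def smoke_test():
    db = OracleDB(
        config={
            "user": "RAG",
            "password": "xxxxxxxxx",
            "dsn": "xxxxxxx_medium",
            "config_dir": "dir/path/to/oracle/config",
            "wallet_location": "dir/path/to/oracle/wallet",
            "wallet_password": "xxxxxxxxx",
            "workspace": "company",
        }
    )
    await db.check_tables()  # creates any missing tables and the property graph
    print(await db.query("SELECT 1 AS ok FROM DUAL"))  # expect {'ok': 1}

asyncio.run(smoke_test())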