Merge pull request #197 from wiltshirek/main

Neo4J integration.
This commit is contained in:
zrguo
2024-11-04 20:49:32 +08:00
committed by GitHub
15 changed files with 5150 additions and 9 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

4
.gitignore vendored
View File

@@ -5,4 +5,8 @@ book.txt
lightrag-dev/ lightrag-dev/
.idea/ .idea/
dist/ dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/ .venv/

56
Dockerfile Normal file
View File

@@ -0,0 +1,56 @@
# Neo4j Enterprise 5.24.2 on slim Debian, with the JRE copied from the
# Temurin 17 image instead of being installed through apt.
FROM debian:bullseye-slim

ENV JAVA_HOME=/opt/java/openjdk
COPY --from=eclipse-temurin:17 $JAVA_HOME $JAVA_HOME

ENV PATH="${JAVA_HOME}/bin:${PATH}" \
    NEO4J_SHA256=7ce97bd9a4348af14df442f00b3dc5085b5983d6f03da643744838c7a1bc8ba7 \
    NEO4J_TARBALL=neo4j-enterprise-5.24.2-unix.tar.gz \
    NEO4J_EDITION=enterprise \
    NEO4J_HOME="/var/lib/neo4j" \
    LANG=C.UTF-8
ARG NEO4J_URI=https://dist.neo4j.org/neo4j-enterprise-5.24.2-unix.tar.gz

# Dedicated system user/group with a fixed uid/gid of 7474.
RUN addgroup --gid 7474 --system neo4j && adduser --uid 7474 --system --no-create-home --home "${NEO4J_HOME}" --ingroup neo4j neo4j

COPY ./local-package/* /startup/

# One layer: install tools, download and verify Neo4j, move data/logs out of
# NEO4J_HOME (symlinked back), build su-exec from a pinned and checksummed
# commit, then purge the build-only packages and the apt lists.
# apt-get is used for both steps because the `apt` CLI is not a stable
# scripting interface.
RUN apt-get update \
    && apt-get install -y curl gcc git jq make procps tini wget \
    && curl --fail --silent --show-error --location --remote-name "${NEO4J_URI}" \
    && echo "${NEO4J_SHA256}  ${NEO4J_TARBALL}" | sha256sum -c --strict --quiet \
    && tar --extract --file "${NEO4J_TARBALL}" --directory /var/lib \
    && mv /var/lib/neo4j-* "${NEO4J_HOME}" \
    && rm "${NEO4J_TARBALL}" \
    && sed -i 's/Package Type:.*/Package Type: docker bullseye/' "${NEO4J_HOME}/packaging_info" \
    && mv /startup/neo4j-admin-report.sh "${NEO4J_HOME}"/bin/neo4j-admin-report \
    && mv "${NEO4J_HOME}"/data /data \
    && mv "${NEO4J_HOME}"/logs /logs \
    && chown -R neo4j:neo4j /data \
    && chmod -R 777 /data \
    && chown -R neo4j:neo4j /logs \
    && chmod -R 777 /logs \
    && chown -R neo4j:neo4j "${NEO4J_HOME}" \
    && chmod -R 777 "${NEO4J_HOME}" \
    && chmod -R 755 "${NEO4J_HOME}/bin" \
    && ln -s /data "${NEO4J_HOME}"/data \
    && ln -s /logs "${NEO4J_HOME}"/logs \
    && git clone https://github.com/ncopa/su-exec.git \
    && cd su-exec \
    && git checkout 4c3bb42b093f14da70d8ab924b487ccfbb1397af \
    && echo d6c40440609a23483f12eb6295b5191e94baf08298a856bab6e15b10c3b82891 su-exec.c | sha256sum -c \
    && echo 2a87af245eb125aca9305a0b1025525ac80825590800f047419dc57bba36b334 Makefile | sha256sum -c \
    && make \
    && mv /su-exec/su-exec /usr/bin/su-exec \
    && apt-get -y purge --auto-remove curl gcc git make \
    && rm -rf /var/lib/apt/lists/* /su-exec

# key=value form; the space-separated `ENV key value` syntax is legacy.
ENV PATH="${NEO4J_HOME}/bin:${PATH}"

WORKDIR "${NEO4J_HOME}"

# Declared after /data and /logs are populated so their contents are kept.
VOLUME /data /logs

# 7474 = HTTP, 7473 = HTTPS, 7687 = Bolt.
EXPOSE 7474 7473 7687

# tini runs as init (-g forwards signals to the whole process group).
ENTRYPOINT ["tini", "-g", "--", "/startup/docker-entrypoint.sh"]
CMD ["neo4j"]

View File

@@ -161,6 +161,39 @@ rag = LightRAG(
``` ```
</details> </details>
<details>
<summary> Using Neo4J for Storage </summary>
* For production-level scenarios you will most likely want to leverage an enterprise solution for KG storage.
* Running Neo4J in Docker is recommended for seamless local testing.
* See: https://hub.docker.com/_/neo4j
```python
export NEO4J_URI="neo4j://localhost:7687"
export NEO4J_USERNAME="neo4j"
export NEO4J_PASSWORD="password"
When you launch the project, be sure to override the default KG (NetworkX)
by specifying kg="Neo4JStorage".
# Note: Default settings use NetworkX
#Initialize LightRAG with Neo4J implementation.
WORKING_DIR = "./local_neo4jWorkDir"
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
kg="Neo4JStorage", #<-----------override KG default
log_level="DEBUG" #<-----------override log_level default
)
```
see test_neo4j.py for a working example.
</details>
<details> <details>
<summary> Using Ollama Models </summary> <summary> Using Ollama Models </summary>

34
get_all_edges_nx.py Normal file
View File

@@ -0,0 +1,34 @@
import networkx as nx
# Load the knowledge graph from the GraphML file at module import time.
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
def get_all_edges_and_nodes(G):
    """Collect every edge of ``G`` together with its endpoint node attributes.

    Args:
        G: A graph exposing ``edges(data=True)`` and a ``nodes`` mapping.

    Returns:
        A list with one dict per edge, holding the keys ``start``, ``end``,
        ``label``, ``properties``, ``start_node_properties`` and
        ``end_node_properties``.
    """
    return [
        {
            'start': src,
            'end': dst,
            # Assumes the edge type is stored under 'label'; '' when absent.
            'label': attrs.get('label', ''),
            'properties': attrs,
            'start_node_properties': G.nodes[src],
            'end_node_properties': G.nodes[dst],
        }
        for src, dst, attrs in G.edges(data=True)
    ]
# Demo: dump every edge of the module-level graph G to stdout.
if __name__ == "__main__":
    # One record per edge, each followed by a '---' separator line.
    for edge in get_all_edges_and_nodes(G):
        print(f"Edge Label: {edge['label']}")
        print(f"Edge Properties: {edge['properties']}")

        print(f"Start Node: {edge['start']}")
        print(f"Start Node Properties: {edge['start_node_properties']}")

        print(f"End Node: {edge['end']}")
        print(f"End Node Properties: {edge['end_node_properties']}")

        print("---")

File diff suppressed because it is too large Load Diff

3
lightrag/kg/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
# print ("init package vars here. ......")

278
lightrag/kg/neo4j_impl.py Normal file
View File

@@ -0,0 +1,278 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast, Tuple, List, Dict
import numpy as np
import inspect
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
)
from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
@dataclass
class Neo4JStorage(BaseGraphStorage):
    """Graph storage backend that keeps the knowledge graph in Neo4j.

    Every node is addressed by a Neo4j *label* derived from its id (the id
    with surrounding double quotes removed). Edges are written as
    ``DIRECTED`` relationships. Connection settings are read from the
    ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` environment
    variables when the instance is created.
    """

    @staticmethod
    def load_nx_graph(file_name):
        """Placeholder: nothing is loaded from ``file_name``; prints a notice."""
        print ("no preloading of graph with neo4j in production")

    def __init__(self, namespace, global_config):
        """Create the async Neo4j driver from environment variables.

        Raises:
            KeyError: if NEO4J_URI, NEO4J_USERNAME or NEO4J_PASSWORD is unset.
        """
        super().__init__(namespace=namespace, global_config=global_config)
        self._driver_lock = asyncio.Lock()
        URI = os.environ["NEO4J_URI"]
        USERNAME = os.environ["NEO4J_USERNAME"]
        PASSWORD = os.environ["NEO4J_PASSWORD"]
        self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    @staticmethod
    def _label(node_id: str) -> str:
        """Turn a node id into text that is safe inside a backtick-quoted label.

        Surrounding double quotes are stripped, as before. Backticks are
        doubled because labels cannot be passed as query parameters and a raw
        backtick would terminate the quoted identifier in the Cypher text.
        """
        return node_id.strip('\"').replace("`", "``")

    async def close(self):
        """Close the driver if it is still open; safe to call repeatedly."""
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aenter__(self):
        """Allow ``async with Neo4JStorage(...) as storage``."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the driver when leaving an ``async with`` block."""
        await self.close()

    async def index_done_callback(self):
        print ("KG successfully indexed.")

    async def has_node(self, node_id: str) -> bool:
        """Return True if at least one node carries the label for ``node_id``."""
        entity_name_label = self._label(node_id)
        query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"

        async with self._driver.session() as session:
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
            )
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any relationship, in either direction, links the two labels."""
        entity_name_label_source = self._label(source_node_id)
        entity_name_label_target = self._label(target_node_id)
        query = (
            f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
            "RETURN COUNT(r) > 0 AS edgeExists"
        )

        async with self._driver.session() as session:
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
            )
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the properties of the node labelled ``node_id``, or None if absent."""
        entity_name_label = self._label(node_id)
        query = f"MATCH (n:`{entity_name_label}`) RETURN n"

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if record:
                node_dict = dict(record["n"])
                logger.debug(
                    f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
                )
                return node_dict
            return None

    async def node_degree(self, node_id: str) -> Union[int, None]:
        """Return the number of relationships attached to the node.

        Returns:
            The relationship count, or None when no node has this label.
        """
        entity_name_label = self._label(node_id)
        query = f"""
            MATCH (n:`{entity_name_label}`)
            RETURN COUNT{{ (n)--() }} AS totalEdgeCount
        """

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if record:
                edge_count = record["totalEdgeCount"]
                logger.debug(
                    f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
                )
                return edge_count
            return None

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of both endpoint degrees; a missing node counts as 0."""
        # Raw ids are passed on: node_degree derives the label itself.
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
            f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
        )
        return degrees

    async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
        """Return the properties of one edge from the source label to the target label.

        Args:
            source_node_id (str): Id of the source node (used as label)
            target_node_id (str): Id of the target node (used as label)

        Returns:
            dict of the first matching relationship's properties, or None
            when no such relationship exists.
        """
        entity_name_label_source = self._label(source_node_id)
        entity_name_label_target = self._label(target_node_id)
        # NOTE(review): this match is directed (source -> target) whereas
        # has_edge matches either direction -- confirm that is intended.
        # The f-string is complete as written; an extra str.format() pass would
        # fail on labels that contain '{' or '}'.
        query = f"""
            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
            RETURN properties(r) as edge_properties
            LIMIT 1
        """

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if record:
                edge_properties = dict(record["edge_properties"])
                logger.debug(
                    f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_properties}'
                )
                return edge_properties
            return None

    async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
        """Return the edges touching the node labelled ``source_node_id``.

        Returns:
            A list of ``(source_label, connected_label)`` tuples, one per
            relationship; each label is the first label of the respective node.
        """
        node_label = self._label(source_node_id)
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""

        async with self._driver.session() as session:
            results = await session.run(query)
            edges = []
            async for record in results:
                source_node = record['n']
                # OPTIONAL MATCH yields connected = None for an isolated node.
                connected_node = record['connected']

                source_label = list(source_node.labels)[0] if source_node.labels else None
                target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None

                if source_label and target_label:
                    edges.append((source_label, target_label))

            return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
    )
    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
        """
        Upsert a node in the Neo4j database.

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties

        Raises:
            Exception: any error from the write is logged and re-raised; the
                neo4j availability/transient errors listed in the decorator
                are retried up to three attempts first.
        """
        label = self._label(node_id)
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
            query = f"""
            MERGE (n:`{label}`)
            SET n += $properties
            """
            await tx.run(query, properties=properties)
            logger.debug(f"Upserted node with label '{label}' and properties: {properties}")

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
    )
    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
        """
        Upsert an edge and its properties between two nodes identified by their labels.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            Exception: any error from the write is logged and re-raised; the
                neo4j availability/transient errors listed in the decorator
                are retried up to three attempts first.
        """
        source_node_label = self._label(source_node_id)
        target_node_label = self._label(target_node_id)
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            # Both endpoints are MATCHed, so nothing is written unless both
            # nodes already exist.
            query = f"""
            MATCH (source:`{source_node_label}`)
            WITH source
            MATCH (target:`{target_node_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += $properties
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
            logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def _node2vec_embed(self):
        print ("Implemented but never called.")

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
import importlib
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@@ -23,6 +24,18 @@ from .storage import (
NanoVectorDBStorage, NanoVectorDBStorage,
NetworkXStorage, NetworkXStorage,
) )
from .kg.neo4j_impl import (
Neo4JStorage
)
#future KG integrations
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
from .utils import ( from .utils import (
EmbeddingFunc, EmbeddingFunc,
compute_mdhash_id, compute_mdhash_id,
@@ -44,18 +57,27 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
logger.info("Creating a new event loop in a sub-thread.") logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop() # loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
return loop return loop
@dataclass @dataclass
class LightRAG: class LightRAG:
working_dir: str = field( working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
) )
kg: str = field(default="NetworkXStorage")
current_log_level = logger.level
log_level: str = field(default=current_log_level)
# text chunking # text chunking
chunk_token_size: int = 1200 chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100 chunk_overlap_token_size: int = 100
@@ -94,7 +116,6 @@ class LightRAG:
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict) vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
enable_llm_cache: bool = True enable_llm_cache: bool = True
# extension # extension
@@ -104,11 +125,16 @@ class LightRAG:
def __post_init__(self): def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log") log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file) set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}") logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
#@TODO: should move all storage setup here to leverage initial start params attached to self.
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
if not os.path.exists(self.working_dir): if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir) os.makedirs(self.working_dir)
@@ -161,6 +187,12 @@ class LightRAG:
**self.llm_model_kwargs, **self.llm_model_kwargs,
) )
) )
def _get_storage_class(self) -> dict:
    """Return the mapping from ``kg`` option names to graph storage classes.

    The caller indexes the result with ``self.kg``, so the return value is a
    name -> class dict rather than a single storage class.
    """
    return {
        "Neo4JStorage": Neo4JStorage,
        "NetworkXStorage": NetworkXStorage,
        # "ArangoDBStorage": ArangoDBStorage
    }
def insert(self, string_or_strings): def insert(self, string_or_strings):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
@@ -298,4 +330,4 @@ class LightRAG:
if storage_inst is None: if storage_inst is None:
continue continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -798,4 +798,4 @@ if __name__ == "__main__":
result = await gpt_4o_mini_complete("How are you?") result = await gpt_4o_mini_complete("How are you?")
print(result) print(result)
asyncio.run(main()) asyncio.run(main())

View File

@@ -466,7 +466,9 @@ async def _build_local_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema], text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam, query_param: QueryParam,
): ):
results = await entities_vdb.query(query, top_k=query_param.top_k) results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
return None return None
node_datas = await asyncio.gather( node_datas = await asyncio.gather(
@@ -481,7 +483,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d} {**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees) for k, n, d in zip(results, node_datas, node_degrees)
if n is not None if n is not None
] ]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities( use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst node_datas, query_param, text_chunks_db, knowledge_graph_inst
) )
@@ -907,7 +909,6 @@ async def hybrid_query(
.strip() .strip()
) )
result = "{" + result.split("{")[1].split("}")[0] + "}" result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result) keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", []) hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", [])
@@ -927,6 +928,7 @@ async def hybrid_query(
query_param, query_param,
) )
if hl_keywords: if hl_keywords:
high_level_context = await _build_global_query_context( high_level_context = await _build_global_query_context(
hl_keywords, hl_keywords,
@@ -937,6 +939,7 @@ async def hybrid_query(
query_param, query_param,
) )
context = combine_contexts(high_level_context, low_level_context) context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context: if query_param.only_need_context:
@@ -1043,6 +1046,7 @@ async def naive_query(
chunks_ids = [r["id"] for r in results] chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids) chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size( maybe_trun_chunks = truncate_list_by_token_size(
chunks, chunks,
key=lambda x: x["content"], key=lambda x: x["content"],
@@ -1073,4 +1077,4 @@ async def naive_query(
.strip() .strip()
) )
return response return response

View File

@@ -233,6 +233,8 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
#@TODO: NOT USED
async def _node2vec_embed(self): async def _node2vec_embed(self):
from graspologic import embed from graspologic import embed

View File

@@ -4,6 +4,7 @@ aiohttp
graspologic graspologic
hnswlib hnswlib
nano-vectordb nano-vectordb
neo4j
networkx networkx
ollama ollama
openai openai

35
test.py Normal file
View File

@@ -0,0 +1,35 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(WORKING_DIR, exist_ok=True)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)

# Explicit encoding so the book is decoded the same way on every platform
# instead of depending on the locale default.
with open("./book.txt", encoding="utf-8") as f:
    rag.insert(f.read())

# Run the same question through each search mode: naive, local, global, hybrid.
for mode in ("naive", "local", "global", "hybrid"):
    print(rag.query("What are the top themes in this story?", param=QueryParam(mode=mode)))

38
test_neo4j.py Normal file
View File

@@ -0,0 +1,38 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./local_neo4jWorkDir"

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(WORKING_DIR, exist_ok=True)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
    kg="Neo4JStorage",  # store the knowledge graph in Neo4j instead of the default
    log_level="INFO"
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)

# Explicit encoding so the book is decoded the same way on every platform
# instead of depending on the locale default.
with open("./book.txt", encoding="utf-8") as f:
    rag.insert(f.read())

# Run the same question through each search mode: naive, local, global, hybrid.
for mode in ("naive", "local", "global", "hybrid"):
    print(rag.query("What are the top themes in this story?", param=QueryParam(mode=mode)))