Merge branch 'main' into main

zrguo
2024-11-07 14:54:15 +08:00
committed by GitHub
12 changed files with 225 additions and 4803 deletions


@@ -498,6 +498,10 @@ pip install fastapi uvicorn pydantic
2. Set up your environment variables:
```bash
export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
export OPENAI_BASE_URL="Your OpenAI API base URL" # Optional: Defaults to "https://api.openai.com/v1"
export OPENAI_API_KEY="Your OpenAI API key" # Required
export LLM_MODEL="Your LLM model" # Optional: Defaults to "gpt-4o-mini"
export EMBEDDING_MODEL="Your embedding model" # Optional: Defaults to "text-embedding-3-large"
```
3. Run the API server:
@@ -522,7 +526,8 @@ The API server provides the following endpoints:
```json
{
"query": "Your question here",
"mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
"mode": "hybrid", // Can be "naive", "local", "global", or "hybrid"
"only_need_context": true // Optional: Defaults to false, if true, only the referenced context will be returned, otherwise the llm answer will be returned
}
```
- **Example:**
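A minimal request sketch (host and port are placeholders; adjust to your deployment):

```python
import requests  # assumed available; pip install requests

# Ask for the retrieved context only, skipping answer generation.
resp = requests.post(
    "http://localhost:8020/query",  # placeholder URL
    json={
        "query": "Your question here",
        "mode": "hybrid",
        "only_need_context": True,
    },
)
print(resp.json())  # e.g. {"status": "success", "data": "..."}
```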


@@ -1,4 +1,4 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
@@ -18,22 +18,28 @@ app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
LLM_MODEL,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs,
)
@@ -44,37 +50,41 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="text-embedding-3-large",
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
model=EMBEDDING_MODEL,
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"{embedding_dim=}")
return embedding_dim
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=asyncio.run(get_embedding_dim()),
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
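Note: `asyncio.run(get_embedding_dim())` issues one real embedding request at startup, so the configured dimension always matches whatever model `EMBEDDING_MODEL` selects; it also means a valid `OPENAI_API_KEY` must be set before the server can boot.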
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
class InsertRequest(BaseModel):
text: str
class InsertFileRequest(BaseModel):
file_path: str
class Response(BaseModel):
status: str
data: Optional[str] = None
@@ -89,7 +99,8 @@ async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None, lambda: rag.query(request.query,
param=QueryParam(mode=request.mode, only_need_context=request.only_need_context))
)
return Response(status="success", data=result)
except Exception as e:
@@ -107,30 +118,22 @@ async def insert_endpoint(request: InsertRequest):
@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
try:
file_content = await file.read()
# Read file content
try:
content = file_content.decode("utf-8")
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
content = file_content.decode("gbk")
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {request.file_path} inserted successfully",
message=f"File content from {file.filename} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
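With `/insert_file` now accepting a multipart upload instead of a server-side path, a client call might look like this (host and port are placeholders):

```python
import requests

# Stream a local text file to the upload endpoint; the server decodes
# UTF-8 first and falls back to GBK, as in the handler above.
with open("book.txt", "rb") as f:
    resp = requests.post(
        "http://localhost:8020/insert_file",  # placeholder URL
        files={"file": ("book.txt", f, "text/plain")},
    )
print(resp.json())
```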


@@ -1,28 +1,34 @@
import networkx as nx
G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml")
def get_all_edges_and_nodes(G):
# Get all edges and their properties
edges_with_properties = []
for u, v, data in G.edges(data=True):
edges_with_properties.append(
{
"start": u,
"end": v,
"label": data.get(
"label", ""
), # Assuming 'label' is used for edge type
"properties": data,
"start_node_properties": G.nodes[u],
"end_node_properties": G.nodes[v],
}
)
return edges_with_properties
# Example usage
if __name__ == "__main__":
# Assume G is your NetworkX graph loaded from Neo4j
all_edges = get_all_edges_and_nodes(G)
# Print all edges and node properties
for edge in all_edges:
print(f"Edge Label: {edge['label']}")
@@ -31,4 +37,4 @@ if __name__ == "__main__":
print(f"Start Node Properties: {edge['start_node_properties']}")
print(f"End Node: {edge['end']}")
print(f"End Node Properties: {edge['end_node_properties']}")
print("---")
print("---")

File diff suppressed because it is too large


@@ -1,17 +1,16 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import inspect
from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
)
from contextlib import asynccontextmanager
from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
class Neo4JStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production")
print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
def __post_init__(self):
@@ -43,59 +44,54 @@ class Neo4JStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print ("KG successfully indexed.")
print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
return single_result["edgeExists"]
def close(self):
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -103,54 +99,51 @@ class Neo4JStorage(BaseGraphStorage):
node = record["n"]
node_dict = dict(node)
logger.debug(
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
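# COUNT { (n)--() } below is a Cypher count subquery (Neo4j 5+) tallying every relationship attached to n.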
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
"""
result = await session.run(query)
record = await result.single()
if record:
edge_count = record["totalEdgeCount"]
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
)
return edge_count
else:
return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
@@ -161,28 +154,30 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
async with self._driver.session() as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = await session.run(query)
""".format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
result = await session.run(query)
record = await result.single()
if record:
result = dict(record["edge_properties"])
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
else:
return None
async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
node_label = source_node_id.strip('\"')
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
results = await session.run(query)
edges = []
async for record in results:
source_node = record["n"]
connected_node = record["connected"]
source_label = (
list(source_node.labels)[0] if source_node.labels else None
)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
return edges
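# The decorator below retries transient Neo4j failures up to 3 times with exponential backoff (waits clamped to 4-10 s).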
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
SET n += $properties
"""
await tx.run(query, properties=properties)
logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
logger.debug(
f"Upserted node with label '{label}' and properties: {properties}"
)
try:
async with self._driver.session() as session:
@@ -236,13 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their labels.
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
RETURN r
"""
await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
)
try:
async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
async def _node2vec_embed(self):
print ("Implemented but never called.")
print("Implemented but never called.")


@@ -1,6 +1,5 @@
import asyncio
import os
import importlib
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
@@ -24,18 +23,15 @@ from .storage import (
NanoVectorDBStorage,
NetworkXStorage,
)
from .kg.neo4j_impl import Neo4JStorage
# future KG integrations
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
from .utils import (
EmbeddingFunc,
compute_mdhash_id,
@@ -56,16 +52,18 @@ from .base import (
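# asyncio.get_event_loop() raises RuntimeError in threads that have no event loop, hence the fallback below.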
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
return asyncio.get_event_loop()
except RuntimeError:
logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
@dataclass
class LightRAG:
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
@@ -75,8 +73,6 @@ class LightRAG:
current_log_level = logger.level
log_level: str = field(default=current_log_level)
# text chunking
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
@@ -131,8 +127,10 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.kg
]
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
@@ -186,6 +184,7 @@ class LightRAG:
**self.llm_model_kwargs,
)
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
"Neo4JStorage": Neo4JStorage,
@@ -329,4 +328,4 @@ class LightRAG:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)


@@ -798,4 +798,4 @@ if __name__ == "__main__":
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())


@@ -466,7 +466,6 @@ async def _build_local_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
@@ -483,7 +482,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
@@ -946,7 +945,6 @@ async def hybrid_query(
query_param,
)
if hl_keywords:
high_level_context = await _build_global_query_context(
hl_keywords,
@@ -957,7 +955,6 @@ async def hybrid_query(
query_param,
)
context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context:
@@ -1026,9 +1023,11 @@ def combine_contexts(high_level_context, low_level_context):
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
# Combine and deduplicate the relationships
combined_relationships = process_combine_contexts(
hl_relationships, ll_relationships
)
# Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources)
@@ -1064,7 +1063,6 @@ async def naive_query(
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
@@ -1095,4 +1093,4 @@ async def naive_query(
.strip()
)
return response


@@ -233,8 +233,7 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
async def _node2vec_embed(self):
from graspologic import embed


@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
import xml.etree.ElementTree as ET
import numpy as np
@@ -176,19 +176,20 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(data)
return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
output = io.StringIO(csv_string)
reader = csv.reader(output)
return [row for row in reader]
def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
@@ -253,13 +254,14 @@ def xml_to_json(xml_file):
print(f"An error occurred: {e}")
return None
def process_combine_contexts(hl, ll):
header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
if list_hl:
header = list_hl[0]
list_hl = list_hl[1:]
if list_ll:
header = list_ll[0]
@@ -268,19 +270,17 @@ def process_combine_contexts(hl, ll):
return ""
if list_hl:
list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [",".join(item[1:]) for item in list_ll if item]
combined_sources_set = set(filter(None, list_hl + list_ll))
combined_sources = [",\t".join(header)]
for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}")
combined_sources = "\n".join(combined_sources)
return combined_sources
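To make the merge behavior concrete, a hypothetical input/output pair (data invented for illustration):

```python
from lightrag.utils import process_combine_contexts  # assumed import path

hl = "id,source\n1,chunk-A\n2,chunk-B"  # high-level context (CSV with header)
ll = "id,source\n1,chunk-B\n2,chunk-C"  # low-level context

print(process_combine_contexts(hl, ll))
# Old ids are dropped, rows are deduplicated via a set (order not guaranteed),
# then renumbered under the shared header, e.g.:
# id,\tsource
# 1,\tchunk-A
# 2,\tchunk-B
# 3,\tchunk-C
```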

test.py

@@ -1,11 +1,10 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -15,7 +14,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
)
@@ -23,13 +22,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)


@@ -5,8 +5,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./local_neo4jWorkDir"
@@ -18,7 +18,7 @@ rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
kg="Neo4JStorage",
log_level="INFO"
log_level="INFO",
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
)
@@ -26,13 +26,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)