inference running locally. use neo4j next
.gitignore
@@ -4,4 +4,5 @@ dickens/
 book.txt
 lightrag-dev/
 .idea/
 dist/
+env/
lightrag/__init__.py
@@ -1,5 +1,27 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
+print("init package vars here. ......")
+from .neo4j import GraphStorage as Neo4JStorage
+
+# import sys
+# import importlib
+# # Specify the path to the directory containing the module
+# # Add the directory to the system path
+# module_dir = '/Users/kenwiltshire/documents/dev/LightRag/lightrag/kg'
+# sys.path.append(module_dir)
+# # Specify the module name
+# module_name = 'neo4j'
+# # Import the module
+# spec = importlib.util.spec_from_file_location(module_name, f'{module_dir}/{module_name}.py')
+# Neo4JStorage = importlib.util.module_from_spec(spec)
+# spec.loader.exec_module(Neo4JStorage)
+
+# Relative imports are still possible by adding a leading period to the module name when using the from ... import form:
+# # Import names from pkg.string
+# from .string import name1, name2
+# # Import pkg.string
+# from . import string
+
 __version__ = "0.0.7"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
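
The commented-out block above sketches loading the Neo4j storage module by file path. A minimal runnable version of that idea, using the same stdlib `importlib` machinery the comments reference (the final attribute lookup is an assumption; note the commented code bound the module, not the class):

```python
# Cleaned-up version of the commented-out dynamic import above.
# The directory path is the one from the commit; GraphStorage is the
# class this package aliases as Neo4JStorage.
import importlib.util

module_dir = "/Users/kenwiltshire/documents/dev/LightRag/lightrag/kg"
module_name = "neo4j"

spec = importlib.util.spec_from_file_location(module_name, f"{module_dir}/{module_name}.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

Neo4JStorage = module.GraphStorage  # assumption: bind the class, not the module
```
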
lightrag/kg/neo4j.py
@@ -3,11 +3,15 @@ import html
 import os
 from dataclasses import dataclass
 from typing import Any, Union, cast
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB

-from .utils import load_json, logger, write_json
+# import package.common.utils as utils
+
+from lightrag.utils import load_json, logger, write_json
+from ..base import (
+    BaseGraphStorage
+)
@@ -22,10 +26,10 @@ PASSWORD = "your_password"
 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
-    def load_nx_graph(file_name) -> nx.Graph:
-        if os.path.exists(file_name):
-            return nx.read_graphml(file_name)
-        return None
+    # def load_nx_graph(file_name) -> nx.Graph:
+    #     if os.path.exists(file_name):
+    #         return nx.read_graphml(file_name)
+    #     return None

     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
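
The hunk context shows module-level connection constants (`PASSWORD = "your_password"`). For reference, a minimal sketch of how the official Neo4j Python driver is typically wired up from such constants; the bolt URI and username here are assumed placeholders, not taken from the diff:

```python
# Minimal connectivity check with the official neo4j driver.
from neo4j import GraphDatabase

URI = "bolt://localhost:7687"   # assumption: not shown in the hunk
USERNAME = "neo4j"              # assumption
PASSWORD = "your_password"      # mirrors the constant in the hunk context

driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
with driver.session() as session:
    print(session.run("RETURN 1 AS ok").single()["ok"])
driver.close()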
@@ -102,7 +106,7 @@ class GraphStorage(BaseGraphStorage):
         result = session.run(
             """MATCH (n1:{node_label1})-[r]-(n2:{node_label2})
             RETURN count(r) AS degree"""
-            .format(node_label1=node_label1, node_label2=node_label2)
+            .format(entity_name__label_source=entity_name__label_source, entity_name_label_target=entity_name_label_target)
         )
         record = result.single()
         return record["degree"]
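
Neo4j does not allow node labels to be bound as query parameters, which is why the code above interpolates them into the Cypher text with `.format()`. A hedged sketch of the same degree query with the labels validated before interpolation (the helper and its names are illustrative, not part of the commit):

```python
import re

LABEL_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

def edge_degree(session, label_source: str, label_target: str) -> int:
    # Labels cannot travel as Cypher parameters, so validate before splicing.
    for label in (label_source, label_target):
        if not LABEL_RE.fullmatch(label):
            raise ValueError(f"unsafe node label: {label!r}")
    query = (
        f"MATCH (n1:{label_source})-[r]-(n2:{label_target}) "
        "RETURN count(r) AS degree"
    )
    return session.run(query).single()["degree"]
```
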
@@ -263,7 +267,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             # Define the Cypher query
             options = self.global_config["node2vec_params"]
-            query = f"""CALL gds.node2vec.stream('myGraph', {**options})
+            query = f"""CALL gds.node2vec.stream('myGraph', {options})  # **options
                 YIELD nodeId, embedding
                 RETURN nodeId, embedding"""
             # Run the query and process the results
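
A dict interpolated into an f-string renders as a Python literal (`{'dimensions': ..., ...}`), which is not guaranteed to parse as a Cypher map. Passing the options map as a bound parameter sidesteps the quoting problem; a sketch assuming the GDS plugin and the projected graph name `'myGraph'` used above:

```python
def node2vec_embeddings(session, options: dict):
    # The node2vec config map travels as a query parameter instead of
    # being spliced into the query text.
    result = session.run(
        "CALL gds.node2vec.stream('myGraph', $options) "
        "YIELD nodeId, embedding "
        "RETURN nodeId, embedding",
        options=options,
    )
    return {record["nodeId"]: record["embedding"] for record in result}
```
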
lightrag/lightrag.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import importlib
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -23,6 +24,11 @@ from .storage import (
     NanoVectorDBStorage,
     NetworkXStorage,
 )
+
+from .kg.neo4j import (
+    GraphStorage as Neo4JStorage
+)
+
 from .utils import (
     EmbeddingFunc,
     compute_mdhash_id,
@@ -93,7 +99,14 @@ class LightRAG:
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
-    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
+
+    # module = importlib.import_module('kg.neo4j')
+    # Neo4JStorage = getattr(module, 'GraphStorage')
+
+    if True==False:
+        graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
+    else:
+        graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
     enable_llm_cache: bool = True

     # extension
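
The `if True==False:` block is a hard-coded stand-in for backend selection (it always picks NetworkXStorage). One natural next step is a config- or environment-driven choice; a sketch under the assumption of an environment variable, where the variable name is invented for illustration:

```python
import os

def resolve_graph_storage():
    # LIGHTRAG_GRAPH_STORAGE is a hypothetical setting, not part of the commit.
    if os.environ.get("LIGHTRAG_GRAPH_STORAGE", "").lower() == "neo4j":
        return Neo4JStorage  # class imported from .kg.neo4j above
    return NetworkXStorage
```
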
lightrag/llm.py
@@ -72,7 +72,9 @@ async def openai_complete_if_cache(

 @retry(
     stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    #kw_
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    # wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def azure_openai_complete_if_cache(model,
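
The change raises the backoff ceiling from 10 to 60 seconds: with tenacity's `wait_exponential(multiplier=1, min=4, max=60)`, waits grow exponentially between a 4-second floor and a 60-second cap, and `stop_after_attempt(3)` gives up after three tries. A self-contained sketch of the same policy on a dummy function (the function itself is illustrative):

```python
import random
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),  # 4s floor, 60s ceiling
)
def flaky_call():
    # Fails most of the time to exercise the retry policy.
    if random.random() < 0.7:
        raise ConnectionError("transient failure")
    return "ok"

print(flaky_call())
```
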
lightrag/storage.py
@@ -6,8 +6,6 @@ from typing import Any, Union, cast
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
-from kg.neo4j import GraphStorage
-

 from .utils import load_json, logger, write_json
 from .base import (
@@ -99,66 +97,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
             d["__vector__"] = embeddings[i]
         results = self._client.upsert(datas=list_data)
         return results

-
-@dataclass
-class PineConeVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = 0.2
-
-    def __post_init__(self):
-        self._client_file_name = os.path.join(
-            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
-        )
-        self._max_batch_size = self.global_config["embedding_batch_num"]
-        self._client = NanoVectorDB(
-            self.embedding_func.embedding_dim, storage_file=self._client_file_name
-        )
-        import os
-        from pinecone import Pinecone
-
-        pc = Pinecone()  # api_key=os.environ.get('PINECONE_API_KEY')
-        # From here on, everything is identical to the REST-based SDK.
-        self._client = pc.Index(host=self._client_pinecone_host)  # 'my-index-8833ca1.svc.us-east1-gcp.pinecone.io'
-
-        self.cosine_better_than_threshold = self.global_config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
-
-    async def upsert(self, data: dict[str, dict]):
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
-        list_data = [
-            {
-                "__id__": k,
-                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
-            }
-            for k, v in data.items()
-        ]
-        contents = [v["content"] for v in data.values()]
-        batches = [
-            contents[i : i + self._max_batch_size]
-            for i in range(0, len(contents), self._max_batch_size)
-        ]
-        embeddings_list = await asyncio.gather(
-            *[self.embedding_func(batch) for batch in batches]
-        )
-        embeddings = np.concatenate(embeddings_list)
-        for i, d in enumerate(list_data):
-            d["__vector__"] = embeddings[i]
-        # self._client.upsert(vectors=[]) pinecone
-        results = self._client.upsert(datas=list_data)
-        return results
-
-    async def query(self, query: str, top_k=5):
-        embedding = await self.embedding_func([query])
-        embedding = embedding[0]
-        # self._client.query(vector=[...], top_key=10) pinecone
-        results = self._client.query(
-            query=embedding,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
-        results = [
-            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
@@ -166,8 +112,7 @@ class PineConeVectorDBStorage(BaseVectorStorage):
         return results

     async def index_done_callback(self):
-        print("self._client.save()")
-        # self._client.save()
+        self._client.save()


 @dataclass
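
The deleted experiment instantiated a Pinecone index but still issued NanoVectorDB-style calls (`upsert(datas=...)`, `query(query=...)`). For reference, the v3 Pinecone SDK that the deleted comments hint at looks roughly like this; the vectors and ids are illustrative, and the host is the one from the deleted comment:

```python
from pinecone import Pinecone

pc = Pinecone()  # reads PINECONE_API_KEY from the environment
index = pc.Index(host="my-index-8833ca1.svc.us-east1-gcp.pinecone.io")

index.upsert(vectors=[{"id": "doc-1", "values": [0.1, 0.2, 0.3]}])
matches = index.query(vector=[0.1, 0.2, 0.3], top_k=5, include_metadata=True)
```
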
@@ -298,5 +243,3 @@ class NetworkXStorage(BaseGraphStorage):

         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
-
-
requirements.txt
@@ -12,4 +12,5 @@ torch
 transformers
 xxhash
 pyvis
 aiohttp
+neo4j
test.py (new file)
@@ -0,0 +1,36 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+
+#########
+# Uncomment the two lines below if running in a jupyter notebook to handle the async nature of rag.insert()
+# import nest_asyncio
+# nest_asyncio.apply()
+#########
+
+WORKING_DIR = "./dickens"
+
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
+    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
+)
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))