first commit

lightrag/lightrag.py (new file, 293 lines)

@@ -0,0 +1,293 @@
import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast

from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
from .operate import (
    chunking_by_token_size,
    extract_entities,
    local_query,
    global_query,
    hybird_query,
    naive_query,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)
from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
)

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # If in a sub-thread, create a new event loop.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
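
# NOTE: asyncio.get_event_loop() is deprecated on newer Python versions when no
# loop is running; asyncio.get_running_loop() plus an explicit new_event_loop()
# fallback is the forward-compatible form of this helper.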

@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )
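
    # The node2vec_params keys above mirror the standard node2vec
    # hyperparameters (embedding dimensionality, walks per node, walk length,
    # skip-gram window size, training iterations, random seed); the storage
    # backend that consumes them is outside this diff.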

    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16

    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        # route all HTTP(S) traffic through a local proxy (development setting)
        os.environ["http_proxy"] = "http://127.0.0.1:7890"
        os.environ["https_proxy"] = "http://127.0.0.1:7890"

        # create the working directory first so the file logger can write into it
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))
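
    # NOTE: run_until_complete() raises RuntimeError if an event loop is already
    # running (e.g. inside Jupyter), so in async contexts await ainsert() /
    # aquery() directly instead of using these blocking wrappers.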

    async def ainsert(self, string_or_strings):
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]
            # ---------- new docs
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            # ---------- commit upsertings and indexing
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()
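
    # Both dedup passes in ainsert() key records by content hash
    # (compute_mdhash_id) and drop already-stored keys via filter_keys(),
    # so re-inserting identical text is effectively a no-op.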

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "hybird":
            response = await hybird_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
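
A minimal usage sketch of the API this commit introduces, assuming the package
exports LightRAG and QueryParam (QueryParam lives in .base, outside this diff)
and that OpenAI credentials are configured for the default model functions:

    from lightrag import LightRAG, QueryParam  # assumed exports

    rag = LightRAG(working_dir="./demo_cache")  # hypothetical cache directory
    rag.insert("Acme Corp acquired Widget Inc in 2021.")

    # the four retrieval modes dispatched by aquery(); "hybird" is this commit's spelling
    for mode in ["naive", "local", "global", "hybird"]:
        print(rag.query("Who acquired Widget Inc?", param=QueryParam(mode=mode)))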