Merge branch 'HKUDS:main' into main
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.0.4"
+__version__ = "1.0.5"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -40,14 +40,6 @@ from .storage import (
     NetworkXStorage,
 )
 
-from .kg.neo4j_impl import Neo4JStorage
-
-from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
-
-from .kg.milvus_impl import MilvusVectorDBStorge
-
-from .kg.mongo_impl import MongoKVStorage
-
 # future KG integrations
 
 # from .kg.ArangoDB_impl import (
@@ -55,6 +47,30 @@ from .kg.mongo_impl import MongoKVStorage
 # )
 
 
+def lazy_external_import(module_name: str, class_name: str):
+    """Lazily import an external module and return a class from it."""
+
+    def import_class():
+        import importlib
+
+        # Import the module using importlib
+        module = importlib.import_module(module_name)
+
+        # Get the class from the module
+        return getattr(module, class_name)
+
+    # Return the import_class function itself, not its result
+    return import_class
+
+
+Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
+OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
+OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
+OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
+MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
+MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
+
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     """
     Ensure that there is always an event loop available.
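
A note for reuse: importlib.import_module raises a TypeError when given a relative module name (leading dot) without a package anchor, which the hunk above does not pass. Below is a minimal standalone sketch of the same lazy-import pattern with that anchor added; the `package` parameter is this sketch's addition, and the stdlib demo (urllib / ".parse" / urlparse) is purely illustrative.

from typing import Optional
import importlib


def lazy_external_import(module_name: str, class_name: str, package: Optional[str] = None):
    """Return a zero-argument loader; the actual import runs on first call."""

    def import_class():
        # A relative name (leading ".") requires the anchor package;
        # the `package` parameter is this sketch's addition, not in the hunk.
        module = importlib.import_module(module_name, package=package)
        return getattr(module, class_name)

    return import_class


# Stdlib demo so the sketch runs as-is: ".parse" anchored at "urllib"
# resolves to urllib.parse, and urlparse is looked up by name.
urlparse_loader = lazy_external_import(".parse", "urlparse", package="urllib")
urlparse = urlparse_loader()  # the import happens here, not at definition time
print(urlparse("https://github.com/HKUDS/LightRAG").netloc)  # github.com

The point of returning import_class rather than calling it is that the import cost, and any missing optional dependency, is deferred until a storage backend is actually requested.
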
@@ -68,7 +84,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         # Try to get the current event loop
         current_loop = asyncio.get_event_loop()
-        if current_loop._closed:
+        if current_loop.is_closed():
             raise RuntimeError("Event loop is closed.")
         return current_loop
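
The fix swaps the private ._closed attribute for the public is_closed() API. For context, the enclosing helper wraps this try in a fallback that builds a fresh loop; a self-contained sketch of the full pattern follows, where the except branch is an assumption since the hunk only shows the try side.

import asyncio


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():  # public API; ._closed is a private attribute
            raise RuntimeError("Event loop is closed.")
        return loop
    except RuntimeError:
        # Assumed fallback: create and register a new loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


loop = always_get_an_event_loop()
loop.run_until_complete(asyncio.sleep(0))  # the returned loop is usable
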
@@ -76,11 +76,24 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
    )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-
-    return content
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        return content
 
 
 @retry(
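
With this change the function returns either a complete string or an async iterator of chunks, matching the Union[str, AsyncIterator[str]] annotation that appears later in this diff. A hedged consumer sketch follows; show and fake_stream are illustrative names, not LightRAG APIs.

import asyncio
from typing import AsyncIterator, Union


async def show(result: Union[str, AsyncIterator[str]]) -> None:
    if hasattr(result, "__aiter__"):  # streaming: drain the chunks
        async for chunk in result:
            print(chunk, end="", flush=True)
        print()
    else:  # non-streaming: one complete string
        print(result)


async def fake_stream() -> AsyncIterator[str]:
    for piece in ("Hel", "lo"):
        yield piece


asyncio.run(show("Hello"))        # plain-string path
asyncio.run(show(fake_stream()))  # streaming path
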
@@ -306,7 +319,7 @@ async def ollama_model_if_cache(
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""
 
         async def inner():
             async for chunk in response:
@@ -447,6 +460,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
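
The new wrapper reads the model name from the shared config on hashing_kv instead of hard-coding one per function, then delegates to openai_complete_if_cache. A toy sketch of that dispatch pattern; FakeKV and complete_if_cache are stand-ins defined here, not LightRAG APIs.

import asyncio


class FakeKV:
    # Stand-in for the real hashing_kv; only global_config is modeled here.
    global_config = {"llm_model_name": "gpt-4o-mini"}


async def complete_if_cache(model, prompt, **kwargs):
    # Stand-in for openai_complete_if_cache; caching and the API call omitted.
    return f"[{model}] response_format={kwargs.get('response_format')} :: {prompt}"


async def openai_complete(prompt, **kwargs):
    if kwargs.pop("keyword_extraction", None):
        kwargs["response_format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await complete_if_cache(model_name, prompt, **kwargs)


print(asyncio.run(openai_complete("hi", keyword_extraction=True, hashing_kv=FakeKV())))
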
@@ -890,6 +919,8 @@ class MultiModel:
         self, prompt, system_prompt=None, history_messages=[], **kwargs
     ) -> str:
         kwargs.pop("model", None)  # stop from overwriting the custom model name
+        kwargs.pop("keyword_extraction", None)
+        kwargs.pop("mode", None)
         next_model = self._next_model()
         args = dict(
             prompt=prompt,
@@ -222,7 +222,7 @@ async def _merge_edges_then_upsert(
         },
     )
     description = await _handle_entity_relation_summary(
-        (src_id, tgt_id), description, global_config
+        f"({src_id}, {tgt_id})", description, global_config
     )
     await knowledge_graph_inst.upsert_edge(
         src_id,
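
This one-line change matters because _handle_entity_relation_summary embeds its first argument in a prompt string; a raw tuple would be rendered via Python's repr, quotes and all:

src_id, tgt_id = "ALICE", "BOB"
print((src_id, tgt_id))         # ('ALICE', 'BOB')  <- tuple repr
print(f"({src_id}, {tgt_id})")  # (ALICE, BOB)      <- readable label
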
@@ -572,7 +572,6 @@ async def kg_query(
             mode=query_param.mode,
         ),
     )
-
    return response
 
 
@@ -488,7 +488,7 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
+    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return
 
     mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
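
The extra hasattr(..., "__aiter__") guard keeps streaming responses out of the cache: an async iterator can only be consumed once and holds no materialized string to store. A small runnable illustration of the check:

import asyncio


async def stream():
    yield "partial "
    yield "chunks"


async def main() -> None:
    for content in ("a complete answer", stream()):
        if hasattr(content, "__aiter__"):
            print("skip cache:", type(content).__name__)
        else:
            print("cache:", content)


asyncio.run(main())
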