Merge branch 'main' into main

This commit is contained in:
wiltshirek
2024-11-01 16:50:45 -04:00
committed by GitHub
20 changed files with 932 additions and 183 deletions

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.7"
__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -109,6 +109,7 @@ class LightRAG:
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -179,7 +180,11 @@ class LightRAG:
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache,
**self.llm_model_kwargs,
)
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
@@ -239,7 +244,7 @@ class LightRAG:
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
knwoledge_graph_inst=self.chunk_entity_relation_graph,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),

View File

@@ -7,7 +7,13 @@ import aiohttp
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
import base64
import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(model,
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs):
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
)
return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -266,10 +282,13 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
@@ -280,8 +299,10 @@ async def ollama_model_if_cache(
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
ollama_client = ollama.AsyncClient()
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -305,6 +326,135 @@ async def ollama_model_if_cache(
return result
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
log_level="WARNING",
)
return lmdeploy_pipe
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
return response
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -328,8 +478,9 @@ async def gpt_4o_mini_complete(
**kwargs,
)
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
@@ -339,6 +490,7 @@ async def azure_openai_complete(
**kwargs,
)
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -418,9 +570,11 @@ async def azure_openai_embedding(
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
@@ -440,35 +594,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith('Bearer '):
api_key = 'Bearer ' + api_key
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {
"Authorization": api_key,
"Content-Type": "application/json"
}
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {
"model": model,
"input": truncate_texts,
"encoding_format": "base64"
}
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if 'code' in content:
if "code" in content:
raise ValueError(content)
base64_strings = [item['embedding'] for item in content['data']]
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack('<' + 'f' * n, decode_bytes)
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
@@ -555,14 +702,16 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama.embeddings(model=embed_model, prompt=text)
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +729,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
gen_func: Callable[[Any], str] = Field(
...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
class Config:
arbitrary_types_allowed = True
class MultiModel():
class MultiModel:
"""
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for spliting across diffrent models or providers.
@@ -611,26 +766,31 @@ class MultiModel():
)
```
"""
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
self,
prompt, system_prompt=None, history_messages=[], **kwargs
self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
return await next_model.gen_func(
**args
args = dict(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
**next_model.kwargs,
)
return await next_model.gen_func(**args)
if __name__ == "__main__":
import asyncio

View File

@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
already_node = await knwoledge_graph_inst.get_node(entity_name)
already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description,
source_id=source_id,
)
await knwoledge_graph_inst.upsert_node(
await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = []
already_keywords = []
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
if not (await knwoledge_graph_inst.has_node(need_insert_id)):
await knwoledge_graph_inst.upsert_node(
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
await knwoledge_graph_inst.upsert_edge(
await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
_merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
@@ -384,7 +384,7 @@ async def extract_entities(
}
await relationships_vdb.upsert(data_for_vdb)
return knwoledge_graph_inst
return knowledge_graph_inst
async def local_query(

View File

@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {
"nodes": [],
"edges": []
}
data = {"nodes": [], "edges": []}
# Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall('.//node', namespace):
for node in root.findall(".//node", namespace):
node_data = {
"id": node.get('id').strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
"id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace):
for edge in root.findall(".//edge", namespace):
edge_data = {
"source": edge.get('source').strip('"'),
"target": edge.get('target').strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)