diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index 581a4187..c88d1c59 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -214,12 +214,6 @@ class NetworkXStorage(BaseGraphStorage):
         """
         labels = set()
         for node in self._graph.nodes():
-            # node_data = dict(self._graph.nodes[node])
-            # if "entity_type" in node_data:
-            #     if isinstance(node_data["entity_type"], list):
-            #         labels.update(node_data["entity_type"])
-            #     else:
-            #         labels.add(node_data["entity_type"])
             labels.add(str(node))  # Add node id as a label

         # Return sorted list
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index d678c611..fb5208b0 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -139,11 +139,14 @@ async def hf_model_complete(

 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     device = next(embed_model.parameters()).device
-    input_ids = tokenizer(
+    encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids.to(device)
+    ).to(device)
     with torch.no_grad():
-        outputs = embed_model(input_ids)
+        outputs = embed_model(
+            input_ids=encoded_texts["input_ids"],
+            attention_mask=encoded_texts["attention_mask"],
+        )
     embeddings = outputs.last_hidden_state.mean(dim=1)
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
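
Note (not part of the patch): with padding=True the tokenizer pads the shorter texts in a batch, so forwarding attention_mask keeps the model from attending to pad tokens. A minimal usage sketch of the revised hf_embed follows; the checkpoint name is illustrative and not taken from this change.

    import asyncio

    from transformers import AutoModel, AutoTokenizer

    from lightrag.llm.hf import hf_embed

    # Illustrative checkpoint; any encoder-style HF model is loaded the same way.
    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
    embed_model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")

    texts = [
        "a short text",
        "a much longer text that forces the tokenizer to pad the short one",
    ]
    # hf_embed is async, so drive it with asyncio.run for a one-off call.
    embeddings = asyncio.run(hf_embed(texts, tokenizer, embed_model))
    print(embeddings.shape)  # (2, hidden_size); mixed lengths exercise the mask path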
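
Note (follow-up idea, not in this patch): outputs.last_hidden_state.mean(dim=1) still averages the hidden states at padded positions; only the model forward uses the mask. A masked mean, sketched below under the patch's own names (last_hidden_state, attention_mask), would pool over real tokens only.

    import torch

    def masked_mean(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Zero out hidden states at padded positions, then divide by the
        # per-sequence count of real tokens instead of the padded length.
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        return summed / counts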