Merge branch 'main' into add-multi-worker-support

This commit is contained in:
yangdx
2025-02-25 11:15:12 +08:00
2 changed files with 6 additions and 9 deletions

View File

@@ -214,12 +214,6 @@ class NetworkXStorage(BaseGraphStorage):
""" """
         labels = set()
         for node in self._graph.nodes():
-            # node_data = dict(self._graph.nodes[node])
-            # if "entity_type" in node_data:
-            #     if isinstance(node_data["entity_type"], list):
-            #         labels.update(node_data["entity_type"])
-            #     else:
-            #         labels.add(node_data["entity_type"])
             labels.add(str(node))  # Add node id as a label
         # Return sorted list

View File

@@ -139,11 +139,14 @@ async def hf_model_complete(
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     device = next(embed_model.parameters()).device
-    input_ids = tokenizer(
+    encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids.to(device)
+    ).to(device)
     with torch.no_grad():
-        outputs = embed_model(input_ids)
+        outputs = embed_model(
+            input_ids=encoded_texts["input_ids"],
+            attention_mask=encoded_texts["attention_mask"],
+        )
         embeddings = outputs.last_hidden_state.mean(dim=1)
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()