From 14a9f74b1726cbfd11b00bbc5a3cbc963e15df2f Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 1 Apr 2024 06:37:38 +0530 Subject: [PATCH] cleanup: use load_file of safetensors directly for loading ip adapters --- invokeai/backend/ip_adapter/ip_adapter.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 02788c0ba6..920cb3780a 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -3,12 +3,14 @@ from typing import List, Optional, TypedDict, Union +import safetensors +import safetensors.torch import torch from PIL import Image -from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights +from invokeai.backend.util.devices import choose_torch_device from ..raw_model import RawModel from .resampler import Resampler @@ -208,12 +210,12 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapter state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}} if ip_adapter_ckpt_path.endswith("safetensors"): - model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt") + model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device) for key in model.keys(): if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key) + state_dict["image_proj"][key.replace("image_proj.", "")] = model[key] elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key) + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key] else: raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.") else: