cleanup: use load_file of safetensors directly for loading ip adapters

This commit is contained in:
blessedcoolant 2024-04-01 06:37:38 +05:30
parent 1372ef15b3
commit 14a9f74b17

View File

@ -3,12 +3,14 @@
from typing import List, Optional, TypedDict, Union from typing import List, Optional, TypedDict, Union
import safetensors
import safetensors.torch
import torch import torch
from PIL import Image from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.util.devices import choose_torch_device
from ..raw_model import RawModel from ..raw_model import RawModel
from .resampler import Resampler from .resampler import Resampler
@ -208,12 +210,12 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapter
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}} state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if ip_adapter_ckpt_path.endswith("safetensors"): if ip_adapter_ckpt_path.endswith("safetensors"):
model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt") model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys(): for key in model.keys():
if key.startswith("image_proj."): if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key) state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."): elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key) state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else: else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.") raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else: else: