mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup: use load_file of safetensors directly for loading ip adapters
This commit is contained in:
parent
1372ef15b3
commit
14a9f74b17
@ -3,12 +3,14 @@
|
|||||||
|
|
||||||
from typing import List, Optional, TypedDict, Union
|
from typing import List, Optional, TypedDict, Union
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors import safe_open
|
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
from .resampler import Resampler
|
from .resampler import Resampler
|
||||||
@ -208,12 +210,12 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapter
|
|||||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||||
|
|
||||||
if ip_adapter_ckpt_path.endswith("safetensors"):
|
if ip_adapter_ckpt_path.endswith("safetensors"):
|
||||||
model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt")
|
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
||||||
for key in model.keys():
|
for key in model.keys():
|
||||||
if key.startswith("image_proj."):
|
if key.startswith("image_proj."):
|
||||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
|
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
|
||||||
elif key.startswith("ip_adapter."):
|
elif key.startswith("ip_adapter."):
|
||||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
|
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user