mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: use Path for ip_adapter_ckpt_path instead of str
This commit is contained in:
parent
14a9f74b17
commit
2dcbb7223b
@ -1,6 +1,7 @@
|
|||||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||||
# and modified as needed
|
# and modified as needed
|
||||||
|
|
||||||
|
import pathlib
|
||||||
from typing import List, Optional, TypedDict, Union
|
from typing import List, Optional, TypedDict, Union
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
@ -10,7 +11,6 @@ from PIL import Image
|
|||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
from .resampler import Resampler
|
from .resampler import Resampler
|
||||||
@ -206,10 +206,10 @@ class IPAdapterPlusXL(IPAdapterPlus):
|
|||||||
).to(self.device, dtype=self.dtype)
|
).to(self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
|
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
|
||||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||||
|
|
||||||
if ip_adapter_ckpt_path.endswith("safetensors"):
|
if ip_adapter_ckpt_path.stem == "safetensors":
|
||||||
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
||||||
for key in model.keys():
|
for key in model.keys():
|
||||||
if key.startswith("image_proj."):
|
if key.startswith("image_proj."):
|
||||||
@ -219,14 +219,14 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapter
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
||||||
else:
|
else:
|
||||||
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
|
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
|
||||||
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
|
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def build_ip_adapter(
|
def build_ip_adapter(
|
||||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
|
||||||
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
|
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
|
||||||
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
|
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
|||||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||||
model_path = Path(config.path)
|
model_path = Path(config.path)
|
||||||
model: RawModel = build_ip_adapter(
|
model: RawModel = build_ip_adapter(
|
||||||
ip_adapter_ckpt_path=str(model_path),
|
ip_adapter_ckpt_path=model_path,
|
||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
dtype=self._torch_dtype,
|
dtype=self._torch_dtype,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user