fix: use Path for ip_adapter_ckpt_path instead of str

This commit is contained in:
blessedcoolant 2024-04-03 20:21:03 +05:30
parent 14a9f74b17
commit 2dcbb7223b
2 changed files with 6 additions and 6 deletions

View File

@ -1,6 +1,7 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
import pathlib
from typing import List, Optional, TypedDict, Union
import safetensors
@ -10,7 +11,6 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.util.devices import choose_torch_device
from ..raw_model import RawModel
from .resampler import Resampler
@ -206,10 +206,10 @@ class IPAdapterPlusXL(IPAdapterPlus):
).to(self.device, dtype=self.dtype)
def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if ip_adapter_ckpt_path.endswith("safetensors"):
if ip_adapter_ckpt_path.stem == "safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
@ -219,14 +219,14 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapter
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
return state_dict
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)

View File

@ -26,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
raise ValueError("There are no submodels in an IP-Adapter model.")
model_path = Path(config.path)
model: RawModel = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path),
ip_adapter_ckpt_path=model_path,
device=torch.device("cpu"),
dtype=self._torch_dtype,
)