Improve robustness of check for IPAdapter vs IPAdapterPlus.

This commit is contained in:
Ryan Dick 2023-09-14 15:24:47 -04:00
parent 781e8521d5
commit a22c8cb3a1
2 changed files with 27 additions and 15 deletions

View File

@ -1,8 +1,9 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
import os
from contextlib import contextmanager
from typing import Optional
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
@ -67,7 +68,7 @@ class IPAdapter:
def __init__(
self,
ip_adapter_ckpt_path: str,
state_dict: dict[torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@ -75,12 +76,11 @@ class IPAdapter:
self.device = device
self.dtype = dtype
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
self._num_tokens = num_tokens
self._clip_image_processor = CLIPImageProcessor()
self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
@ -198,3 +198,18 @@ class IPAdapterPlus(IPAdapter):
]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def build_ip_adapter(
    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
    """Load an IP-Adapter checkpoint and construct the matching adapter class.

    The checkpoint at ``ip_adapter_ckpt_path`` is loaded onto the CPU, and the
    keys of its ``"image_proj"`` entry decide which class is instantiated:
    ``IPAdapter`` when ``"proj.weight"`` is present, ``IPAdapterPlus`` otherwise.

    Args:
        ip_adapter_ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        device: Forwarded unchanged to the adapter constructor.
        dtype: Forwarded unchanged to the adapter constructor.

    Returns:
        The constructed ``IPAdapter`` or ``IPAdapterPlus`` instance.
    """
    # NOTE(review): torch.load unpickles the file; confirm checkpoints come from a trusted source.
    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")

    # The variant is inferred from the image_proj weights in the checkpoint, not from the file path.
    adapter_cls = IPAdapter if "proj.weight" in state_dict["image_proj"] else IPAdapterPlus
    return adapter_cls(state_dict, device=device, dtype=dtype)

View File

@ -5,7 +5,11 @@ from typing import Any, Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.ip_adapter.ip_adapter import (
IPAdapter,
IPAdapterPlus,
build_ip_adapter,
)
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
@ -79,16 +83,9 @@ class IPAdapterModel(ModelBase):
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
# TODO(ryand): Checking for "plus" in the file path is fragile. It should be possible to infer whether this is a
# "plus" variant by loading the state_dict.
if "plus" in str(self.model_path):
return IPAdapterPlus(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
else:
return IPAdapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
return build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
@classmethod
def convert_if_required(