From fbe6452c4578d9794d47416092d16eb5d113e809 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 4 Oct 2023 17:55:19 -0400 Subject: [PATCH] Add support for IPAdapterPlusXL based on https://github.com/tencent-ailab/IP-Adapter/commit/6219530507cb4696a0496f10c0e5a4f1dbdc7672. --- invokeai/backend/ip_adapter/ip_adapter.py | 24 ++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 94f3202ba0..4c846ecf00 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -218,6 +218,20 @@ class IPAdapterPlus(IPAdapter): return image_prompt_embeds, uncond_image_prompt_embeds +class IPAdapterPlusXL(IPAdapterPlus): + """IP-Adapter Plus for SDXL.""" + + def _init_image_proj_model(self, state_dict): + return Resampler.from_state_dict( + state_dict=state_dict, + depth=4, + dim_head=64, + heads=20, + num_queries=self._num_tokens, + ff_mult=4, + ).to(self.device, dtype=self.dtype) + + def build_ip_adapter( ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16 ) -> Union[IPAdapter, IPAdapterPlus]: @@ -228,6 +242,14 @@ def build_ip_adapter( is_plus = "proj.weight" not in state_dict["image_proj"] if is_plus: - return IPAdapterPlus(state_dict, device=device, dtype=dtype) + cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + if cross_attention_dim == 768: + # SD1 IP-Adapter Plus + return IPAdapterPlus(state_dict, device=device, dtype=dtype) + elif cross_attention_dim == 2048: + # SDXL IP-Adapter Plus + return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) + else: + raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.") else: return IPAdapter(state_dict, device=device, dtype=dtype)