Add support for IPAdapterFull models. The changes are based on this upstream PR: https://github.com/tencent-ailab/IP-Adapter/pull/139.

Ryan Dick 2023-11-29 15:10:45 -05:00 committed by Kent Keirsey
parent 77933a0a85
commit 693c6cf5e4
2 changed files with 59 additions and 6 deletions


@@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
         return clip_extra_context_tokens


+class MLPProjModel(torch.nn.Module):
+    """SD model with image prompt"""
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+        super().__init__()
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
+            torch.nn.GELU(),
+            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
+            torch.nn.LayerNorm(cross_attention_dim),
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
+        """Initialize an MLPProjModel from a state_dict.
+
+        The cross_attention_dim and clip_embeddings_dim are inferred from the shapes of the tensors in the state_dict.
+
+        Args:
+            state_dict (dict[str, torch.Tensor]): The state_dict of model weights.
+
+        Returns:
+            MLPProjModel
+        """
+        cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+        clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+
+        model = cls(cross_attention_dim, clip_embeddings_dim)
+
+        model.load_state_dict(state_dict)
+        return model
+
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds)
+        return clip_extra_context_tokens
+
+
 class IPAdapter:
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
@@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
         return image_prompt_embeds, uncond_image_prompt_embeds


+class IPAdapterFull(IPAdapterPlus):
+    """IP-Adapter Plus with full features."""
+
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
+        return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
+
+
 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""
@@ -149,11 +194,9 @@ def build_ip_adapter(
 ) -> Union[IPAdapter, IPAdapterPlus]:
     state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")

-    # Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
-    # contains.
-    is_plus = "proj.weight" not in state_dict["image_proj"]
-
-    if is_plus:
+    if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
+        return IPAdapter(state_dict, device=device, dtype=dtype)
+    elif "proj_in.weight" in state_dict["image_proj"]:  # IPAdapterPlus or IPAdapterPlusXL (with Resampler).
         cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
         if cross_attention_dim == 768:
             # SD1 IP-Adapter Plus
@@ -163,5 +206,7 @@ def build_ip_adapter(
             return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
         else:
             raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
+    elif "proj.0.weight" in state_dict["image_proj"]:  # IPAdapterFull (with MLPProjModel).
+        return IPAdapterFull(state_dict, device=device, dtype=dtype)
     else:
-        return IPAdapter(state_dict, device=device, dtype=dtype)
+        raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
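
build_ip_adapter() now dispatches on which image_proj keys the checkpoint contains, rather than the old binary is_plus flag. A standalone sketch of the same detection logic, using a hypothetical classify_ip_adapter() helper that is not part of the commit:

import torch

def classify_ip_adapter(ip_adapter_ckpt_path: str) -> str:
    """Report which image-projection architecture a checkpoint uses,
    mirroring the key checks in build_ip_adapter()."""
    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
    image_proj = state_dict["image_proj"]
    if "proj.weight" in image_proj:
        return "IPAdapter (ImageProjModel)"
    elif "proj_in.weight" in image_proj:
        return "IPAdapterPlus / IPAdapterPlusXL (Resampler)"
    elif "proj.0.weight" in image_proj:
        return "IPAdapterFull (MLPProjModel)"
    raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")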


@@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
             "unet_model_id": "runwayml/stable-diffusion-v1-5",
             "unet_model_name": "stable-diffusion-v1-5",
         },
+        # SD1.5, IPAdapterFull
+        {
+            "ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
+            "ip_adapter_model_name": "ip-adapter-full-face_sd15",
+            "base_model": BaseModelType.StableDiffusion1,
+            "unet_model_id": "runwayml/stable-diffusion-v1-5",
+            "unet_model_name": "stable-diffusion-v1-5",
+        },
     ],
 )
 @pytest.mark.slow
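
Assuming build_ip_adapter() accepts the ip_adapter_ckpt_path, device, and dtype arguments visible in the diff above, loading the new full-face checkpoint might look like the following sketch. The local path and weight filename are hypothetical; the "InvokeAI/ip-adapter-full-face_sd15" repo name comes from the test parameters.

import torch

# Hypothetical local path to downloaded "InvokeAI/ip-adapter-full-face_sd15" weights.
ckpt_path = "/path/to/ip-adapter-full-face_sd15/ip_adapter.bin"

ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path=ckpt_path,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
# The "proj.0.weight" key in this checkpoint's image_proj should route to IPAdapterFull.
assert isinstance(ip_adapter, IPAdapterFull)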