mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for IPAdapterFull models. The changes are based on this upstream PR: https://github.com/tencent-ailab/IP-Adapter/pull/139 .
This commit is contained in:
parent
77933a0a85
commit
693c6cf5e4
@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
"""SD model with image prompt"""
|
||||
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
|
||||
torch.nn.LayerNorm(cross_attention_dim),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
|
||||
"""Initialize an MLPProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (dict[torch.Tensor]): The state_dict of model weights.
|
||||
|
||||
Returns:
|
||||
MLPProjModel
|
||||
"""
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
|
||||
model = cls(cross_attention_dim, clip_embeddings_dim)
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
class IPAdapterFull(IPAdapterPlus):
|
||||
"""IP-Adapter Plus with full features."""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
|
||||
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
class IPAdapterPlusXL(IPAdapterPlus):
|
||||
"""IP-Adapter Plus for SDXL."""
|
||||
|
||||
@ -149,11 +194,9 @@ def build_ip_adapter(
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||
|
||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
||||
# contains.
|
||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
||||
|
||||
if is_plus:
|
||||
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
# SD1 IP-Adapter Plus
|
||||
@ -163,5 +206,7 @@ def build_ip_adapter(
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||
|
@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
|
||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"unet_model_name": "stable-diffusion-v1-5",
|
||||
},
|
||||
# SD1.5, IPAdapterFull
|
||||
{
|
||||
"ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
|
||||
"ip_adapter_model_name": "ip-adapter-full-face_sd15",
|
||||
"base_model": BaseModelType.StableDiffusion1,
|
||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"unet_model_name": "stable-diffusion-v1-5",
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.slow
|
||||
|
Loading…
Reference in New Issue
Block a user