mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for IPAdapterFull models. The changes are based on this upstream PR: https://github.com/tencent-ailab/IP-Adapter/pull/139 .
This commit is contained in:
parent
77933a0a85
commit
693c6cf5e4
@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
|
|||||||
return clip_extra_context_tokens
|
return clip_extra_context_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class MLPProjModel(torch.nn.Module):
    """MLP that projects CLIP image embeddings to the cross-attention dimension.

    The projection is ``Linear -> GELU -> Linear -> LayerNorm``. The first linear layer keeps the
    embedding width (``clip_embeddings_dim``); the second maps it to ``cross_attention_dim``.
    """

    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
        """Build the projection layers.

        Args:
            cross_attention_dim: Output feature size of the projection.
            clip_embeddings_dim: Input feature size of the projection.
        """
        super().__init__()

        # NOTE: The layer indices inside this Sequential ("proj.0", "proj.3", ...) are relied upon by
        # from_state_dict() when it infers the dimensions, so do not reorder the layers.
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "MLPProjModel":
        """Initialize an MLPProjModel from a state_dict.

        The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.

        Args:
            state_dict (dict[str, torch.Tensor]): The state_dict of model weights, keyed by parameter name.

        Returns:
            MLPProjModel

        Raises:
            KeyError: If "proj.3.weight" or "proj.0.weight" is missing from the state_dict.
        """
        # "proj.3" is the final LayerNorm; its weight has one entry per output feature.
        cross_attention_dim = state_dict["proj.3.weight"].shape[0]
        # "proj.0" is the first Linear layer, which is square (clip_embeddings_dim x clip_embeddings_dim).
        clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]

        model = cls(cross_attention_dim, clip_embeddings_dim)

        model.load_state_dict(state_dict)
        return model

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Project image embeddings through the MLP and return the result."""
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
|
||||||
|
|
||||||
|
|
||||||
class IPAdapter:
|
class IPAdapter:
|
||||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||||
|
|
||||||
@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
|
|||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter Plus variant whose image projection model is an MLPProjModel."""

    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
        """Build the image projection model from its weights.

        Args:
            state_dict: Weights for the image projection model, keyed by parameter name.
                (Presumably the "image_proj" sub-dict of the checkpoint - confirm against the caller.)

        Returns:
            An MLPProjModel loaded from state_dict and moved to this adapter's device and dtype.
        """
        return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterPlusXL(IPAdapterPlus):
|
class IPAdapterPlusXL(IPAdapterPlus):
|
||||||
"""IP-Adapter Plus for SDXL."""
|
"""IP-Adapter Plus for SDXL."""
|
||||||
|
|
||||||
@ -149,11 +194,9 @@ def build_ip_adapter(
|
|||||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||||
# contains.
|
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
|
||||||
|
|
||||||
if is_plus:
|
|
||||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||||
if cross_attention_dim == 768:
|
if cross_attention_dim == 768:
|
||||||
# SD1 IP-Adapter Plus
|
# SD1 IP-Adapter Plus
|
||||||
@ -163,5 +206,7 @@ def build_ip_adapter(
|
|||||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||||
|
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||||
|
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||||
|
@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
|
|||||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||||
"unet_model_name": "stable-diffusion-v1-5",
|
"unet_model_name": "stable-diffusion-v1-5",
|
||||||
},
|
},
|
||||||
|
# SD1.5, IPAdapterFull
|
||||||
|
{
|
||||||
|
"ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
|
||||||
|
"ip_adapter_model_name": "ip-adapter-full-face_sd15",
|
||||||
|
"base_model": BaseModelType.StableDiffusion1,
|
||||||
|
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||||
|
"unet_model_name": "stable-diffusion-v1-5",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
Loading…
Reference in New Issue
Block a user