2023-08-29 13:29:05 +00:00
|
|
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
|
|
|
# and modified as needed
|
|
|
|
|
2023-09-14 19:24:47 +00:00
|
|
|
from typing import Optional, Union
|
2023-09-08 19:39:22 +00:00
|
|
|
|
2023-08-29 07:51:55 +00:00
|
|
|
import torch
|
2023-09-07 18:10:42 +00:00
|
|
|
from PIL import Image
|
|
|
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|
|
|
|
2023-10-06 15:05:25 +00:00
|
|
|
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
2023-09-25 22:28:10 +00:00
|
|
|
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
|
|
|
|
2023-08-29 07:51:55 +00:00
|
|
|
from .resampler import Resampler
|
|
|
|
|
|
|
|
|
|
|
|
class ImageProjModel(torch.nn.Module):
    """Image Projection Model.

    Projects an image embedding into `clip_extra_context_tokens` tokens, each of width `cross_attention_dim`, and
    applies a LayerNorm over the last dimension.
    """

    def __init__(
        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
    ):
        """Build the projection and normalization layers.

        Args:
            cross_attention_dim (int, optional): Width of each produced context token. Defaults to 1024.
            clip_embeddings_dim (int, optional): Width of the incoming image embedding. Defaults to 1024.
            clip_extra_context_tokens (int, optional): Number of context tokens produced per embedding. Defaults to 4.
        """
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # A single linear layer emits all tokens at once; forward() splits its output into individual tokens.
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
        """Initialize an ImageProjModel from a state_dict.

        The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.

        Args:
            state_dict (dict[str, torch.Tensor]): The state_dict of model weights. Must contain "norm.weight" and
                "proj.weight".
            clip_extra_context_tokens (int, optional): Defaults to 4.

        Returns:
            ImageProjModel
        """
        # LayerNorm's weight has one entry per normalized feature, i.e. cross_attention_dim entries.
        cross_attention_dim = state_dict["norm.weight"].shape[0]
        # Linear's weight is laid out as (out_features, in_features), so the last axis is the input width.
        clip_embeddings_dim = state_dict["proj.weight"].shape[-1]

        model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)

        model.load_state_dict(state_dict)
        return model

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Project image embeddings into normalized context tokens.

        Args:
            image_embeds (torch.Tensor): Image embeddings whose last dimension matches the projection's input width.

        Returns:
            torch.Tensor: Tensor reshaped to (-1, clip_extra_context_tokens, cross_attention_dim) and layer-normalized.
        """
        projected = self.proj(image_embeds)
        tokens = projected.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter:
    """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""

    def __init__(
        self,
        state_dict: dict[str, torch.Tensor],
        device: torch.device,
        dtype: torch.dtype = torch.float16,
        num_tokens: int = 4,
    ):
        """Build the adapter from a checkpoint state_dict.

        Args:
            state_dict: Checkpoint contents; the "image_proj" and "ip_adapter" entries are read.
            device: Device the image projection model and attention weights are moved to.
            dtype: Dtype the image projection model and attention weights are cast to.
            num_tokens: Token count forwarded to the image projection model.
        """
        # These three attributes must be set first: _init_image_proj_model() reads all of them.
        self.device = device
        self.dtype = dtype
        self._num_tokens = num_tokens

        self._clip_image_processor = CLIPImageProcessor()

        image_proj_weights = state_dict["image_proj"]
        self._image_proj_model = self._init_image_proj_model(image_proj_weights)

        attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"])
        self.attn_weights = attn_weights.to(self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
        """Move the adapter's sub-models to `device`, and cast them to `dtype` when one is given.

        When `dtype` is None the previously stored dtype is kept.
        """
        self.device = device
        if dtype is not None:
            self.dtype = dtype

        for component in (self._image_proj_model, self.attn_weights):
            component.to(device=self.device, dtype=self.dtype)

    def calc_size(self):
        """Return the combined size of the image projection model and the attention weights."""
        image_proj_size = calc_model_size_by_data(self._image_proj_model)
        attn_weights_size = calc_model_size_by_data(self.attn_weights)
        return image_proj_size + attn_weights_size

    def _init_image_proj_model(self, state_dict):
        """Create the image projection model from its weights and place it on the adapter's device/dtype."""
        image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
        return image_proj_model.to(self.device, dtype=self.dtype)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
        """Compute conditional and unconditional image prompt embeddings.

        Args:
            pil_image: A single PIL image (wrapped into a one-element list) or a value passed through unchanged to
                the CLIP image processor.
            image_encoder: Encoder whose `image_embeds` output is fed to the image projection model.

        Returns:
            A tuple (image_prompt_embeds, uncond_image_prompt_embeds). The unconditional embeddings are obtained by
            projecting an all-zero tensor shaped like the encoder's image embeddings.
        """
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image

        pixel_values = self._clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device, dtype=self.dtype)
        encoded = image_encoder(pixel_values).image_embeds

        cond_embeds = self._image_proj_model(encoded)
        uncond_embeds = self._image_proj_model(torch.zeros_like(encoded))
        return cond_embeds, uncond_embeds
|
2023-08-29 13:29:05 +00:00
|
|
|
|
|
|
|
|
2023-08-29 07:51:55 +00:00
|
|
|
class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def _init_image_proj_model(self, state_dict):
        """Create a Resampler (12 heads) from its weights and place it on the adapter's device/dtype."""
        resampler = Resampler.from_state_dict(
            state_dict=state_dict,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self._num_tokens,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=self.dtype)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
        """Compute conditional and unconditional image prompt embeddings from penultimate hidden states.

        Args:
            pil_image: A single PIL image (wrapped into a one-element list) or a value passed through unchanged to
                the CLIP image processor.
            image_encoder: Encoder called with `output_hidden_states=True`; its second-to-last hidden state is used.

        Returns:
            A tuple (image_prompt_embeds, uncond_image_prompt_embeds). The unconditional embeddings come from
            encoding an all-zero tensor shaped like the preprocessed image tensor.
        """

        def penultimate_hidden_state(pixels):
            # hidden_states[-2] is the second-to-last entry of the encoder's hidden-state outputs.
            return image_encoder(pixels, output_hidden_states=True).hidden_states[-2]

        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image

        pixel_values = self._clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device, dtype=self.dtype)

        cond_embeds = self._image_proj_model(penultimate_hidden_state(pixel_values))

        blank_pixels = torch.zeros_like(pixel_values)
        uncond_embeds = self._image_proj_model(penultimate_hidden_state(blank_pixels))

        return cond_embeds, uncond_embeds
|
2023-09-14 19:24:47 +00:00
|
|
|
|
|
|
|
|
2023-10-04 21:55:19 +00:00
|
|
|
class IPAdapterPlusXL(IPAdapterPlus):
    """IP-Adapter Plus for SDXL."""

    def _init_image_proj_model(self, state_dict):
        """Create a Resampler (20 heads) from its weights and place it on the adapter's device/dtype."""
        resampler = Resampler.from_state_dict(
            state_dict=state_dict,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self._num_tokens,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=self.dtype)
|
|
|
|
|
|
|
|
|
2023-09-14 19:24:47 +00:00
|
|
|
def build_ip_adapter(
    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
    """Load an IP-Adapter checkpoint and build the matching adapter class.

    Args:
        ip_adapter_ckpt_path: Path to the checkpoint file; it is loaded onto the CPU first.
        device: Device passed through to the adapter constructor.
        dtype: Dtype passed through to the adapter constructor.

    Returns:
        An IPAdapter when the image_proj weights contain "proj.weight"; otherwise an IPAdapterPlus (cross-attention
        dimension 768) or an IPAdapterPlusXL (cross-attention dimension 2048).

    Raises:
        ValueError: If the checkpoint is a "Plus" variant with an unsupported cross-attention dimension.
    """
    # SECURITY NOTE(review): torch.load() may unpickle arbitrary objects depending on the torch version and its
    # `weights_only` setting. Only load checkpoints from trusted sources.
    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")

    # Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
    # contains.
    if "proj.weight" in state_dict["image_proj"]:
        return IPAdapter(state_dict, device=device, dtype=dtype)

    cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]

    if cross_attention_dim == 768:
        # SD1 IP-Adapter Plus
        return IPAdapterPlus(state_dict, device=device, dtype=dtype)

    if cross_attention_dim == 2048:
        # SDXL IP-Adapter Plus
        return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)

    # ValueError is still caught by callers that handle the generic Exception raised previously.
    raise ValueError(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|