From 7afdefb0e55dfc71d15a92f420c0d7d44a51b883 Mon Sep 17 00:00:00 2001 From: user1 Date: Tue, 29 Aug 2023 00:51:55 -0700 Subject: [PATCH 1/5] Core ip_adapter files from https://github.com/tencent-ailab/IP-Adapter Copied into InvokeAI since IP-Adapter repo is not a package. Is there a better way to do this for non-packaged Python code while still keeping InvokeAI install easy? --- invokeai/backend/ip_adapter/__init__.py | 1 + .../backend/ip_adapter/attention_processor.py | 390 ++++++++++++++++++ invokeai/backend/ip_adapter/ip_adapter.py | 243 +++++++++++ invokeai/backend/ip_adapter/resampler.py | 121 ++++++ invokeai/backend/ip_adapter/utils.py | 368 +++++++++++++++++ 5 files changed, 1123 insertions(+) create mode 100644 invokeai/backend/ip_adapter/__init__.py create mode 100644 invokeai/backend/ip_adapter/attention_processor.py create mode 100644 invokeai/backend/ip_adapter/ip_adapter.py create mode 100644 invokeai/backend/ip_adapter/resampler.py create mode 100644 invokeai/backend/ip_adapter/utils.py diff --git a/invokeai/backend/ip_adapter/__init__.py b/invokeai/backend/ip_adapter/__init__.py new file mode 100644 index 0000000000..852ee25813 --- /dev/null +++ b/invokeai/backend/ip_adapter/__init__.py @@ -0,0 +1 @@ +from .ip_adapter import IPAdapter, IPAdapterXL, IPAdapterPlus diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py new file mode 100644 index 0000000000..4754be00e0 --- /dev/null +++ b/invokeai/backend/ip_adapter/attention_processor.py @@ -0,0 +1,390 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
+ """ + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + text_context_len (`int`, defaults to 77): + The context length of the text features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.text_context_len = text_context_len + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :] + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + text_context_len (`int`, defaults to 77): + The context length of the text features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.text_context_len = text_context_len + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :] + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * 
ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py new file mode 100644 index 0000000000..5d5d0af71b --- /dev/null +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -0,0 +1,243 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor +from PIL import Image + +from .utils import is_torch2_available +if is_torch2_available: + from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor +else: + from .attention_processor import IPAttnProcessor, AttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class IPAdapter: + + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): + + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, 
+ scale=1.0).to(self.device, dtype=torch.float16) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=-1, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if isinstance(pil_image, Image.Image): + num_prompts = 1 + else: + num_prompts = len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds = self.pipe._encode_prompt( + prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=-1, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if isinstance(pil_image, Image.Image): + num_prompts = 1 + else: + num_prompts = len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): 
+ prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt( + prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4 + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds \ No newline at end of file diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py new file mode 100644 index 0000000000..4521c8c3e6 --- /dev/null +++ b/invokeai/backend/ip_adapter/resampler.py @@ -0,0 +1,121 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> 
(bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/invokeai/backend/ip_adapter/utils.py b/invokeai/backend/ip_adapter/utils.py new file mode 100644 index 0000000000..10218092ed --- /dev/null +++ b/invokeai/backend/ip_adapter/utils.py @@ -0,0 +1,368 @@ +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from diffusers.utils import is_compiled_module +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.models import ControlNetModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + + + +def is_torch2_available(): + return hasattr(F, "scaled_dot_product_attention") + + +@torch.no_grad() +def generate( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: 
Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, +): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
+ """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds[:, :77, :].chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds[:, :77, :] + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file From 69d37217b872ce4be6ad0f5058d0765e3fb5adbe Mon Sep 17 00:00:00 2001 From: user1 Date: Tue, 29 Aug 2023 06:29:05 -0700 Subject: [PATCH 2/5] Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top. --- .../backend/ip_adapter/attention_processor.py | 34 ++++--- invokeai/backend/ip_adapter/ip_adapter.py | 93 +++++++++++-------- invokeai/backend/ip_adapter/resampler.py | 30 +++--- invokeai/backend/ip_adapter/utils.py | 5 +- 4 files changed, 94 insertions(+), 68 deletions(-) diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py index 4754be00e0..de9b367b7d 100644 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ b/invokeai/backend/ip_adapter/attention_processor.py @@ -1,3 +1,7 @@ +# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) +# and modified as needed + +# tencent-ailab comment: # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py import torch import torch.nn as nn @@ -74,8 +78,8 @@ class AttnProcessor(nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. 
@@ -134,7 +138,7 @@ class IPAttnProcessor(nn.Module): encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - + # split hidden states encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :] @@ -148,18 +152,18 @@ class IPAttnProcessor(nn.Module): attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) - + # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) - + ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) - + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj @@ -176,8 +180,8 @@ class IPAttnProcessor(nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -264,8 +268,8 @@ class AttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - - + + class IPAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. @@ -355,11 +359,11 @@ class IPAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - + # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) - + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -368,10 +372,10 @@ class IPAttnProcessor2_0(torch.nn.Module): ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) - + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 5d5d0af71b..ddec16eebc 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -1,3 +1,6 @@ +# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) +# and modified as needed + import os from typing import List @@ -6,11 +9,14 @@ from diffusers import StableDiffusionPipeline from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor from PIL import Image -from .utils import is_torch2_available -if is_torch2_available: - from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor -else: - from .attention_processor import IPAttnProcessor, AttnProcessor +# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor +# so for now falling back to the default versions +# from .utils import is_torch2_available +# if is_torch2_available: +# from .attention_processor import 
IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor +# else: +# from .attention_processor import IPAttnProcessor, AttnProcessor +from .attention_processor import IPAttnProcessor, AttnProcessor from .resampler import Resampler @@ -18,12 +24,12 @@ class ImageProjModel(torch.nn.Module): """Projection Model""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() - + self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) - + def forward(self, image_embeds): embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) @@ -32,25 +38,29 @@ class ImageProjModel(torch.nn.Module): class IPAdapter: - + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): - + self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.num_tokens = num_tokens - - self.pipe = sd_pipe.to(self.device) + + # FIXME: + # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used + # so for now assuming that pipeline is already on the correct device + # self.pipe = sd_pipe.to(self.device) + self.pipe = sd_pipe self.set_ip_adapter() - + # load image encoder self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16) self.clip_image_processor = CLIPImageProcessor() # image proj model self.image_proj_model = self.init_proj() - + self.load_ip_adapter() - + def init_proj(self): image_proj_model = ImageProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, @@ -58,10 +68,12 @@ class IPAdapter: clip_extra_context_tokens=self.num_tokens, ).to(self.device, dtype=torch.float16) return image_proj_model - + def set_ip_adapter(self): unet = self.pipe.unet attn_procs = {} + print("Original UNet Attn Processors count:", len(unet.attn_processors)) + print(unet.attn_processors.keys()) for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -75,16 +87,19 @@ class IPAdapter: if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: + print("swapping in IPAttnProcessor for", name) attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) - + print("Modified UNet Attn Processors count:", len(unet.attn_processors)) + print(unet.attn_processors.keys()) + def load_ip_adapter(self): state_dict = torch.load(self.ip_ckpt, map_location="cpu") self.image_proj_model.load_state_dict(state_dict["image_proj"]) ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) ip_layers.load_state_dict(state_dict["ip_adapter"]) - + @torch.inference_mode() def get_image_embeds(self, pil_image): if isinstance(pil_image, Image.Image): @@ -94,12 +109,14 @@ class IPAdapter: image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds - + def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): if 
isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale - + + # IPAdapter.generate() method is not used for InvokeAI + # left here for reference def generate( self, pil_image, @@ -113,22 +130,22 @@ class IPAdapter: **kwargs, ): self.set_scale(scale) - + if isinstance(pil_image, Image.Image): num_prompts = 1 else: num_prompts = len(pil_image) - + if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - + if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -142,7 +159,7 @@ class IPAdapter: negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.pipe( prompt_embeds=prompt_embeds, @@ -152,13 +169,13 @@ class IPAdapter: generator=generator, **kwargs, ).images - + return images - - + + class IPAdapterXL(IPAdapter): """SDXL""" - + def generate( self, pil_image, @@ -171,22 +188,22 @@ class IPAdapterXL(IPAdapter): **kwargs, ): self.set_scale(scale) - + if isinstance(pil_image, Image.Image): num_prompts = 1 else: num_prompts = len(pil_image) - + if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - + if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -199,7 +216,7 @@ class IPAdapterXL(IPAdapter): prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.pipe( prompt_embeds=prompt_embeds, @@ -210,10 +227,10 @@ class IPAdapterXL(IPAdapter): generator=generator, **kwargs, ).images - + return images - - + + class IPAdapterPlus(IPAdapter): """IP-Adapter with fine-grained features""" @@ -229,7 +246,7 @@ class IPAdapterPlus(IPAdapter): ff_mult=4 ).to(self.device, dtype=torch.float16) return image_proj_model - + @torch.inference_mode() def get_image_embeds(self, pil_image): if isinstance(pil_image, Image.Image): @@ -240,4 +257,4 @@ class IPAdapterPlus(IPAdapter): image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds \ No newline at end of file + return image_prompt_embeds, uncond_image_prompt_embeds 
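[Editor's note, not part of the patch] set_ip_adapter() above swaps IPAttnProcessor into every cross-attention layer via unet.set_attn_processor(), modifying the UNet in place. A minimal sketch of how a caller might make that reversible is shown below; it only assumes the diffusers attn_processors / set_attn_processor API already used in this file, and the helper name with_ip_adapter and its parameters are illustrative, not part of this patch.

```python
from PIL import Image

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter


def with_ip_adapter(pipe, image_encoder_path: str, ip_ckpt_path: str, pil_image: Image.Image):
    """Illustrative helper: apply an IP-Adapter, compute image embeddings, then undo
    the in-place UNet modification made by IPAdapter.set_ip_adapter()."""
    original_procs = dict(pipe.unet.attn_processors)  # snapshot the current processor mapping
    try:
        ip_adapter = IPAdapter(pipe, image_encoder_path, ip_ckpt_path, device="cuda")
        # Returns (image_prompt_embeds, uncond_image_prompt_embeds): 4 extra context
        # tokens per image, ready to be concatenated onto the text embeddings.
        return ip_adapter.get_image_embeds(pil_image)
    finally:
        # Restore the stock processors so later runs without an IP-Adapter image
        # see an unmodified UNet.
        pipe.unet.set_attn_processor(original_procs)
```

The same concern is flagged by the FIXME in the POC commit further below (how to avoid re-modifying an already-patched UNet, and how to undo the change when ip_adapter_image is removed).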
diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py index 4521c8c3e6..327ef7c140 100644 --- a/invokeai/backend/ip_adapter/resampler.py +++ b/invokeai/backend/ip_adapter/resampler.py @@ -1,4 +1,6 @@ -# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) + +# tencent ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py import math import torch @@ -14,8 +16,8 @@ def FeedForward(dim, mult=4): nn.GELU(), nn.Linear(inner_dim, dim, bias=False), ) - - + + def reshape_tensor(x, heads): bs, length, width = x.shape #(bs, length, width) --> (bs, length, n_heads, dim_per_head) @@ -53,13 +55,13 @@ class PerceiverAttention(nn.Module): """ x = self.norm1(x) latents = self.norm2(latents) - + b, l, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) - + q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) @@ -69,7 +71,7 @@ class PerceiverAttention(nn.Module): weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v - + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) @@ -88,14 +90,14 @@ class Resampler(nn.Module): ff_mult=4, ): super().__init__() - + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) - + self.proj_in = nn.Linear(embedding_dim, dim) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) - + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( @@ -108,14 +110,14 @@ class Resampler(nn.Module): ) def forward(self, x): - + latents = self.latents.repeat(x.size(0), 1, 1) - + x = self.proj_in(x) - + for attn, ff in self.layers: latents = attn(x, latents) + latents latents = ff(latents) + latents - + latents = self.proj_out(latents) - return self.norm_out(latents) \ No newline at end of file + return self.norm_out(latents) diff --git a/invokeai/backend/ip_adapter/utils.py b/invokeai/backend/ip_adapter/utils.py index 10218092ed..e120a9e2b4 100644 --- a/invokeai/backend/ip_adapter/utils.py +++ b/invokeai/backend/ip_adapter/utils.py @@ -1,3 +1,6 @@ +# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) +# and modified as needed + import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -365,4 +368,4 @@ def generate( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 9ed4d487d25c0f789030b089e57d3f588bece1dd Mon Sep 17 00:00:00 2001 From: user1 Date: Tue, 29 Aug 2023 06:31:24 -0700 Subject: [PATCH 3/5] Working POC of IP-Adapters. Not fully nodified yet. 
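[Editor's note, illustrative only] This POC works by concatenating the image prompt tokens onto the text tokens along the sequence dimension and letting IPAttnProcessor split them back apart at text_context_len=77. The shape walk-through below assumes SD 1.5 dimensions (77 text tokens, cross_attention_dim=768, num_tokens=4); those numbers are stated assumptions for illustration, not values checked into this patch.

```python
# Illustrative shape walk-through (assumed SD 1.5 dims: 77 text tokens,
# cross_attention_dim = 768, num_tokens = 4 image tokens).
import torch

batch = 1
text_embeds = torch.randn(batch, 77, 768)   # text conditioning tokens
image_embeds = torch.randn(batch, 4, 768)   # output of ImageProjModel: 4 extra context tokens

# What the POC does in diffusers_pipeline.py: concatenate along the sequence
# dimension so both token sets flow through cross-attention together.
concat = torch.cat([text_embeds, image_embeds], dim=1)
assert concat.shape == (batch, 81, 768)

# IPAttnProcessor then splits them back apart at text_context_len = 77.
text_part, ip_part = concat[:, :77, :], concat[:, 77:, :]
assert text_part.shape == (batch, 77, 768) and ip_part.shape == (batch, 4, 768)
```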
--- .../stable_diffusion/diffusers_pipeline.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 2d1894c896..f40e4dae03 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -34,6 +34,7 @@ from .diffusion import ( BasicConditioningInfo, ) from ..util import normalize_device, auto_detect_slice_size +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter @dataclass @@ -357,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, callback: Callable[[PipelineIntermediateState], None] = None, control_data: List[ControlNetData] = None, + ip_adapter_image: Optional[PIL.Image] = None, mask: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None, seed: Optional[int] = None, @@ -408,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data, additional_guidance=additional_guidance, control_data=control_data, + ip_adapter_image=ip_adapter_image, callback=callback, ) finally: @@ -427,8 +430,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): *, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, + ip_adapter_image: Optional[PIL.Image] = None, callback: Callable[[PipelineIntermediateState], None] = None, ): + self._adjust_memory_efficient_attention(latents) if additional_guidance is None: additional_guidance = [] @@ -439,6 +444,55 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if timesteps.shape[0] == 0: return latents, attention_map_saver + print("ip_adapter_image: ", type(ip_adapter_image)) + if ip_adapter_image is not None: + # initialize IPAdapter + print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height) + clip_image_encoder_path = "ip_adapter_models_sd_1.5/image_encoder/" + ip_adapter_model_path = "ip_adapter_models_sd_1.5/ip-adapter_sd15.bin" + # FIXME: + # WARNING! + # IPAdapter constructor modifies UNet model in-place + # Adds additional cross-attention layers to UNet model for image embedding + # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter + # and how to undo if ip_adapter_image is removed + # use existing model management context etc? + # + ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline + clip_image_encoder_path, # hardwiring to manually downloaded dir for first pass + ip_adapter_model_path, # hardwiring to manually downloaded loc for first pass + "cuda") # hardwiring CUDA GPU for first pass + # IP-Adapter ==> add additional cross-attention layers to UNet model here? 
+ print("ip_adapter:", ip_adapter) + + # get image embedding from CLIP and ImageProjModel + print("getting image embeddings from IP-Adapter...") + num_samples = 1 # hardwiring for first pass + image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image) + print("image cond embeds shape:", image_prompt_embeds.shape) + print("image uncond embeds shape:", uncond_image_prompt_embeds.shape) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + print("image cond embeds shape:", image_prompt_embeds.shape) + print("image uncond embeds shape:", uncond_image_prompt_embeds.shape) + + # IP-Adapter: run IP-Adapter model here? + # and add output as additional cross-attention layers + text_prompt_embeds = conditioning_data.text_embeddings.embeds + uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds + print("text embeds shape:", text_prompt_embeds.shape) + concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1) + concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1) + print("concat embeds shape:", concat_prompt_embeds.shape) + conditioning_data.text_embeddings.embeds = concat_prompt_embeds + conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds + else: + image_prompt_embeds = None + uncond_image_prompt_embeds = None + extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, From 35b7ae90aeed9d3793cac11efc0e12d490442186 Mon Sep 17 00:00:00 2001 From: user1 Date: Tue, 29 Aug 2023 06:32:48 -0700 Subject: [PATCH 4/5] Working POC for IP-Adapters. Not fully nodified yet, lots of caveats, hardwired model paths, etc. 
--- invokeai/app/invocations/latent.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 80988f3c71..f5ef356134 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -52,9 +52,10 @@ from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField -DEFAULT_PRECISION = choose_precision(choose_torch_device()) +DEFAULT_PRECISION = choose_precision(choose_torch_device()) + SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] @@ -191,6 +192,7 @@ class DenoiseLatentsInvocation(BaseInvocation): default=None, description=FieldDescriptions.mask, ) + ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection) @validator("cfg_scale") def ge_one(cls, v): @@ -476,6 +478,13 @@ class DenoiseLatentsInvocation(BaseInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) + if self.ip_adapter_image is not None: + print("ip_adapter_image:", self.ip_adapter_image) + unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name) + print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image) + else: + unwrapped_ip_adapter_image = None + control_data = self.prep_control_data( model=pipeline, context=context, @@ -504,7 +513,8 @@ class DenoiseLatentsInvocation(BaseInvocation): masked_latents=masked_latents, num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] + control_data=control_data, # list[ControlNetData], + ip_adapter_image=unwrapped_ip_adapter_image, callback=step_callback, ) From 5f4a62810edf5c2dfb687b67844d313896f99b7a Mon Sep 17 00:00:00 2001 From: user1 Date: Tue, 29 Aug 2023 10:42:42 -0700 Subject: [PATCH 5/5] Added ip_adapter_strength parameter to adjust weighting of IP-Adapter's added cross-attention layers --- invokeai/app/invocations/latent.py | 5 ++++- invokeai/backend/stable_diffusion/diffusers_pipeline.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index f5ef356134..5f00995414 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -192,7 +192,9 @@ class DenoiseLatentsInvocation(BaseInvocation): default=None, description=FieldDescriptions.mask, ) - ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection) + ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6) + ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float, + title="IP Adapter Strength", ui_order=7) @validator("cfg_scale") def ge_one(cls, v): @@ -515,6 +517,7 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData], ip_adapter_image=unwrapped_ip_adapter_image, + ip_adapter_strength=self.ip_adapter_strength, callback=step_callback, ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index f40e4dae03..3335c8866f 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -359,6 +359,7 @@ class 
StableDiffusionGeneratorPipeline(StableDiffusionPipeline): callback: Callable[[PipelineIntermediateState], None] = None, control_data: List[ControlNetData] = None, ip_adapter_image: Optional[PIL.Image] = None, + ip_adapter_strength: float = 1.0, mask: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None, seed: Optional[int] = None, @@ -411,6 +412,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance=additional_guidance, control_data=control_data, ip_adapter_image=ip_adapter_image, + ip_adapter_strength=ip_adapter_strength, callback=callback, ) finally: @@ -431,6 +433,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, ip_adapter_image: Optional[PIL.Image] = None, + ip_adapter_strength: float = 1.0, callback: Callable[[PipelineIntermediateState], None] = None, ): @@ -463,6 +466,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ip_adapter_model_path, # hardwiring to manually downloaded loc for first pass "cuda") # hardwiring CUDA GPU for first pass # IP-Adapter ==> add additional cross-attention layers to UNet model here? + ip_adapter.set_scale(ip_adapter_strength) print("ip_adapter:", ip_adapter) # get image embedding from CLIP and ImageProjModel