diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 4754be00e0..de9b367b7d 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -1,3 +1,7 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
+
+# tencent-ailab comment:
 # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
 import torch
 import torch.nn as nn
@@ -74,8 +78,8 @@ class AttnProcessor(nn.Module):
         hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
-        
-        
+
+
 class IPAttnProcessor(nn.Module):
     r"""
     Attention processor for IP-Adapater.
@@ -134,7 +138,7 @@ class IPAttnProcessor(nn.Module):
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        
+
         # split hidden states
         encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
 
@@ -148,18 +152,18 @@ class IPAttnProcessor(nn.Module):
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
-        
+
         # for ip-adapter
         ip_key = self.to_k_ip(ip_hidden_states)
         ip_value = self.to_v_ip(ip_hidden_states)
-        
+
         ip_key = attn.head_to_batch_dim(ip_key)
         ip_value = attn.head_to_batch_dim(ip_value)
-        
+
         ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
         ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-        
+
         hidden_states = hidden_states + self.scale * ip_hidden_states
 
         # linear proj
@@ -176,8 +180,8 @@ class IPAttnProcessor(nn.Module):
         hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
-        
-        
+
+
 class AttnProcessor2_0(torch.nn.Module):
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -264,8 +268,8 @@ class AttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
-        
-        
+
+
 class IPAttnProcessor2_0(torch.nn.Module):
     r"""
     Attention processor for IP-Adapater for PyTorch 2.0.
@@ -355,11 +359,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
-        
+
         # for ip-adapter
         ip_key = self.to_k_ip(ip_hidden_states)
         ip_value = self.to_v_ip(ip_hidden_states)
-        
+
         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -368,10 +372,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
         ip_hidden_states = F.scaled_dot_product_attention(
             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
         )
-        
+
         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         ip_hidden_states = ip_hidden_states.to(query.dtype)
-        
+
         hidden_states = hidden_states + self.scale * ip_hidden_states
 
         # linear proj
diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index 5d5d0af71b..ddec16eebc 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
+
 import os
 from typing import List
 
@@ -6,11 +9,14 @@ from diffusers import StableDiffusionPipeline
 from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
 from PIL import Image
 
-from .utils import is_torch2_available
-if is_torch2_available:
-    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
-else:
-    from .attention_processor import IPAttnProcessor, AttnProcessor
+# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
+# so for now falling back to the default versions
+# from .utils import is_torch2_available
+# if is_torch2_available:
+#     from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+# else:
+#     from .attention_processor import IPAttnProcessor, AttnProcessor
+from .attention_processor import IPAttnProcessor, AttnProcessor
 from .resampler import Resampler
 
 
@@ -18,12 +24,12 @@ class ImageProjModel(torch.nn.Module):
     """Projection Model"""
     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
-        
+
         self.cross_attention_dim = cross_attention_dim
         self.clip_extra_context_tokens = clip_extra_context_tokens
         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
-        
+
     def forward(self, image_embeds):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
@@ -32,25 +38,29 @@ class ImageProjModel(torch.nn.Module):
 
 
 class IPAdapter:
-    
+
     def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
-        
+
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens
-        
-        self.pipe = sd_pipe.to(self.device)
+
+        # FIXME:
+        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
+        # so for now assuming that pipeline is already on the correct device
+        # self.pipe = sd_pipe.to(self.device)
+        self.pipe = sd_pipe
         self.set_ip_adapter()
-        
+
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
         self.image_proj_model = self.init_proj()
-        
+
         self.load_ip_adapter()
-        
+
     def init_proj(self):
         image_proj_model = ImageProjModel(
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
@@ -58,10 +68,12 @@ class IPAdapter:
             clip_extra_context_tokens=self.num_tokens,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
-        
+
     def set_ip_adapter(self):
         unet = self.pipe.unet
         attn_procs = {}
+        print("Original UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())
         for name in unet.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -75,16 +87,19 @@ class IPAdapter:
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor()
             else:
+                print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                    scale=1.0).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
-        
+        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())
+
     def load_ip_adapter(self):
         state_dict = torch.load(self.ip_ckpt, map_location="cpu")
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])
-        
+
     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
@@ -94,12 +109,14 @@ class IPAdapter:
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
-        
+
     def set_scale(self, scale):
         for attn_processor in self.pipe.unet.attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
-        
+
+    # IPAdapter.generate() method is not used for InvokeAI
+    # left here for reference
     def generate(
         self,
         pil_image,
@@ -113,22 +130,22 @@ class IPAdapter:
         **kwargs,
     ):
         self.set_scale(scale)
-        
+
         if isinstance(pil_image, Image.Image):
             num_prompts = 1
         else:
             num_prompts = len(pil_image)
-        
+
         if prompt is None:
             prompt = "best quality, high quality"
         if negative_prompt is None:
             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-        
+
         if not isinstance(prompt, List):
             prompt = [prompt] * num_prompts
         if not isinstance(negative_prompt, List):
             negative_prompt = [negative_prompt] * num_prompts
-        
+
         image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
         bs_embed, seq_len, _ = image_prompt_embeds.shape
         image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
@@ -142,7 +159,7 @@ class IPAdapter:
             negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
             prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
             negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-        
+
         generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
         images = self.pipe(
             prompt_embeds=prompt_embeds,
@@ -152,13 +169,13 @@ class IPAdapter:
             generator=generator,
             **kwargs,
         ).images
-        
+
         return images
-        
-        
+
+
 class IPAdapterXL(IPAdapter):
     """SDXL"""
-    
+
     def generate(
         self,
         pil_image,
@@ -171,22 +188,22 @@ class IPAdapterXL(IPAdapter):
         **kwargs,
     ):
         self.set_scale(scale)
-        
+
         if isinstance(pil_image, Image.Image):
             num_prompts = 1
         else:
             num_prompts = len(pil_image)
-        
+
         if prompt is None:
             prompt = "best quality, high quality"
         if negative_prompt is None:
             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-        
+
         if not isinstance(prompt, List):
             prompt = [prompt] * num_prompts
         if not isinstance(negative_prompt, List):
             negative_prompt = [negative_prompt] * num_prompts
-        
+
         image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
         bs_embed, seq_len, _ = image_prompt_embeds.shape
         image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
@@ -199,7 +216,7 @@ class IPAdapterXL(IPAdapter):
                 prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
             prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
             negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-        
+
         generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
         images = self.pipe(
             prompt_embeds=prompt_embeds,
@@ -210,10 +227,10 @@ class IPAdapterXL(IPAdapter):
             generator=generator,
             **kwargs,
         ).images
-        
+
         return images
-    
-    
+
+
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
@@ -229,7 +246,7 @@ class IPAdapterPlus(IPAdapter):
             ff_mult=4
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
-    
+
     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
@@ -240,4 +257,4 @@ class IPAdapterPlus(IPAdapter):
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
         uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
-        return image_prompt_embeds, uncond_image_prompt_embeds
\ No newline at end of file
+        return image_prompt_embeds, uncond_image_prompt_embeds
diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py
index 4521c8c3e6..327ef7c140 100644
--- a/invokeai/backend/ip_adapter/resampler.py
+++ b/invokeai/backend/ip_adapter/resampler.py
@@ -1,4 +1,6 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+
+# tencent ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
 import math
 
 import torch
@@ -14,8 +16,8 @@ def FeedForward(dim, mult=4):
         nn.GELU(),
         nn.Linear(inner_dim, dim, bias=False),
     )
-    
-    
+
+
 def reshape_tensor(x, heads):
     bs, length, width = x.shape
     #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
@@ -53,13 +55,13 @@ class PerceiverAttention(nn.Module):
         """
         x = self.norm1(x)
         latents = self.norm2(latents)
-        
+
         b, l, _ = latents.shape
 
         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-        
+
         q = reshape_tensor(q, self.heads)
         k = reshape_tensor(k, self.heads)
         v = reshape_tensor(v, self.heads)
@@ -69,7 +71,7 @@ class PerceiverAttention(nn.Module):
         weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
-        
+
         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
 
         return self.to_out(out)
@@ -88,14 +90,14 @@ class Resampler(nn.Module):
         ff_mult=4,
     ):
         super().__init__()
-        
+
         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
-        
+
         self.proj_in = nn.Linear(embedding_dim, dim)
 
         self.proj_out = nn.Linear(dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
-        
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(
@@ -108,14 +110,14 @@ class Resampler(nn.Module):
             )
 
     def forward(self, x):
-        
+
         latents = self.latents.repeat(x.size(0), 1, 1)
-        
+
         x = self.proj_in(x)
-        
+
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
-            
+
         latents = self.proj_out(latents)
-        return self.norm_out(latents)
\ No newline at end of file
+        return self.norm_out(latents)
diff --git a/invokeai/backend/ip_adapter/utils.py b/invokeai/backend/ip_adapter/utils.py
index 10218092ed..e120a9e2b4 100644
--- a/invokeai/backend/ip_adapter/utils.py
+++ b/invokeai/backend/ip_adapter/utils.py
@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
+
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -365,4 +368,4 @@ def generate(
 
     if not return_dict:
         return (image, has_nsfw_concept)
-    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
+    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
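
For reference, a minimal sketch of how the copied IPAdapter class might be driven once this patch is applied. It assumes a plain diffusers StableDiffusionPipeline that has already been moved to the target device (the modified __init__ no longer calls sd_pipe.to()), and the encoder/checkpoint paths and input image are hypothetical placeholders, not files shipped with this change.

# usage sketch, not part of the patch
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter

device = "cuda"
# the modified IPAdapter no longer moves the pipeline, so put it on the device here
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to(device)

ip_adapter = IPAdapter(
    sd_pipe=pipe,
    image_encoder_path="path/to/image_encoder",  # hypothetical CLIP vision encoder dir
    ip_ckpt="path/to/ip-adapter_sd15.bin",       # hypothetical ip-adapter checkpoint
    device=device,
)

# project a reference image into prompt-embedding space (cond and uncond)
reference = Image.open("reference.png")          # hypothetical input image
image_embeds, uncond_image_embeds = ip_adapter.get_image_embeds(reference)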