diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index 2e0944a969..e173547202 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -1,7 +1,10 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed
 
+from contextlib import contextmanager
+
 import torch
+from diffusers.models import UNet2DConditionModel
 
 # FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
 # so for now falling back to the default versions
@@ -38,18 +41,14 @@ class ImageProjModel(torch.nn.Module):
 
 
 class IPAdapter:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
+    def __init__(self, unet: UNet2DConditionModel, image_encoder_path, ip_ckpt, device, num_tokens=4):
+        self._unet = unet
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens
 
-        # FIXME:
-        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
-        # so for now assuming that pipeline is already on the correct device
-        # self.pipe = sd_pipe.to(self.device)
-        self.pipe = sd_pipe
-        self.set_ip_adapter()
+        self._attn_processors = self._prepare_attention_processors()
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
@@ -63,44 +62,55 @@ class IPAdapter:
 
     def init_proj(self):
        image_proj_model = ImageProjModel(
-            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            cross_attention_dim=self._unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
 
-    def set_ip_adapter(self):
-        unet = self.pipe.unet
+    def _prepare_attention_processors(self):
+        """Creates a dict of attention processors that can later be injected into `self._unet`, and loads the IP-Adapter
+        attention weights into them.
+ """ attn_procs = {} - print("Original UNet Attn Processors count:", len(unet.attn_processors)) - print(unet.attn_processors.keys()) - for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + for name in self._unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] + hidden_size = self._unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] + hidden_size = self._unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: - print("swapping in IPAttnProcessor for", name) attn_procs[name] = IPAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, ).to(self.device, dtype=torch.float16) - unet.set_attn_processor(attn_procs) - print("Modified UNet Attn Processors count:", len(unet.attn_processors)) - print(unet.attn_processors.keys()) + return attn_procs + + @contextmanager + def apply_ip_adapter_attention(self): + """A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active. + + Yields: + None + """ + orig_attn_processors = self._unet.attn_processors + try: + self._unet.set_attn_processor(self._attn_processors) + yield None + finally: + self._unet.set_attn_processor(orig_attn_processors) def load_ip_adapter(self): state_dict = torch.load(self.ip_ckpt, map_location="cpu") self.image_proj_model.load_state_dict(state_dict["image_proj"]) - ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers = torch.nn.ModuleList(self._attn_processors.values()) ip_layers.load_state_dict(state_dict["ip_adapter"]) @torch.inference_mode() @@ -114,139 +124,15 @@ class IPAdapter: return image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): - for attn_processor in self.pipe.unet.attn_processors.values(): + for attn_processor in self._attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale - # IPAdapter.generate() method is not used for InvokeAI. 
@@ -114,139 +124,15 @@ class IPAdapter:
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
+        for attn_processor in self._attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
 
-    # IPAdapter.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     guidance_scale=7.5,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-    #     with torch.inference_mode():
-    #         prompt_embeds = self.pipe._encode_prompt(
-    #             prompt,
-    #             device=self.device,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
-    #         prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         guidance_scale=guidance_scale,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-
-    #     return images
-
 
 class IPAdapterXL(IPAdapter):
     """SDXL"""
 
     pass
 
-    # IPAdapterXL.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-    #     with torch.inference_mode():
-    #         (
-    #             prompt_embeds,
-    #             negative_prompt_embeds,
-    #             pooled_prompt_embeds,
-    #             negative_pooled_prompt_embeds,
-    #         ) = self.pipe.encode_prompt(
-    #             prompt,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         pooled_prompt_embeds=pooled_prompt_embeds,
-    #         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-
-    #     return images
 
 class IPAdapterPlus(IPAdapter):
@@ -254,13 +140,13 @@ class IPAdapterPlus(IPAdapter):
 
     def init_proj(self):
         image_proj_model = Resampler(
-            dim=self.pipe.unet.config.cross_attention_dim,
+            dim=self._unet.config.cross_attention_dim,
             depth=4,
             dim_head=64,
             heads=12,
             num_queries=self.num_tokens,
             embedding_dim=self.image_encoder.config.hidden_size,
-            output_dim=self.pipe.unet.config.cross_attention_dim,
+            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
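Note on `_prepare_attention_processors()`: names ending in `attn1.processor` belong to self-attention layers and keep the stock `AttnProcessor`; only cross-attention (`attn2`) layers get an `IPAttnProcessor`, which needs each layer's hidden size for its extra key/value projections. A self-contained sketch of the name-to-hidden-size mapping, assuming SD 1.x-style `block_out_channels` (the values and the `hidden_size_for` helper are illustrative):

```python
# block_out_channels mirrors the SD 1.x UNet default; hidden_size_for is a
# stand-in for the inline logic in _prepare_attention_processors().
block_out_channels = (320, 640, 1280, 1280)


def hidden_size_for(name: str) -> int:
    if name.startswith("mid_block"):
        return block_out_channels[-1]  # the mid block runs at the deepest width
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]  # up blocks walk the widths in reverse
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected attention processor name: {name}")


assert hidden_size_for("down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor") == 320
assert hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn2.processor") == 1280
assert hidden_size_for("up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor") == 320
```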
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 138e3c9cea..66417e56df 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import dataclasses
 import inspect
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Union
 
@@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         if ip_adapter_data is not None:
             # Initialize IPAdapter
-            # FIXME:
-            # WARNING!
-            # IPAdapter constructor modifies UNet model in-place
-            # Adds additional cross-attention layers to UNet model for image embedding
-            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
-            # and how to undo if ip_adapter_image is removed
-            # Should reimplement to use existing model management context etc.
-            #
+            # TODO(ryand): Refactor to use model management for the IP-Adapter.
             if "sdxl" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterXL(
-                    self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
+                    self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
                 )
             elif "plus" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterPlus(
-                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
@@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
             else:
                 ip_adapter = IPAdapter(
-                    self,  # IPAdapter first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
@@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 image_prompt_embeds, uncond_image_prompt_embeds
             )
 
-        # TODO(ryand): Apply IP-Adapter or custom attention control
-        extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(
-            self.invokeai_diffuser.model,
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
-        ):
+        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+                self.invokeai_diffuser.model,
+                extra_conditioning_info=conditioning_data.extra,
+                step_count=len(self.scheduler.timesteps),
+            )
+        elif ip_adapter_data is not None:
+            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
+            # As it is now, the IP-Adapter will silently be skipped.
+            attn_ctx = ip_adapter.apply_ip_adapter_attention()
+        else:
+            attn_ctx = nullcontext()
+
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
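Note: the denoising loop now picks exactly one attention context up front and enters it through a single `with`: cross-attention control when requested, else the IP-Adapter patch, else a no-op `nullcontext()`. A minimal, self-contained sketch of this dispatch (the two context managers are stand-ins, not the InvokeAI implementations):

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def cross_attention_control_ctx():
    print("cross-attention control active")
    yield


@contextmanager
def ip_adapter_ctx():
    print("IP-Adapter processors active")
    yield


def denoise(wants_cross_attention_control: bool, has_ip_adapter: bool) -> None:
    # Choose the context object first, then enter whichever one was picked.
    if wants_cross_attention_control:
        attn_ctx = cross_attention_control_ctx()
    elif has_ip_adapter:
        attn_ctx = ip_adapter_ctx()
    else:
        attn_ctx = nullcontext()

    with attn_ctx:
        print("denoising steps run here")


denoise(wants_cross_attention_control=False, has_ip_adapter=True)
```

As the TODO in the hunk notes, when both are requested the `elif` silently drops the IP-Adapter: `denoise(True, True)` would activate only cross-attention control.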
diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
index 35d4800859..85d1a5ddcd 100644
--- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
+++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
@@ -11,16 +11,17 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.attention_processor import (
     Attention,
+    AttentionProcessor,
     AttnProcessor,
     SlicedAttnProcessor,
 )
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from torch import nn
 
 import invokeai.backend.util.logging as logger
+
 from ...util import torch_dtype
 
 
@@ -380,10 +381,13 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
             # non-fatal error but .swap() won't work.
             logger.error(
                 f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-                + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+                + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching"
+                " failed "
                 + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
-                + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-                + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+                + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who"
+                " knows "
+                + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will"
+                " not "
                 + "work properly until it is fixed."
             )
     return attention_module_tuples
@@ -581,6 +585,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask=None,
         # kwargs
         swap_cross_attn_context: SwapCrossAttnContext = None,
+        **kwargs,
    ):
         attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 8473fa7bcc..f79f5e2559 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -67,30 +67,26 @@ class InvokeAIDiffuserComponent:
 
     @contextmanager
     def custom_attention_context(
         self,
-        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        unet: UNet2DConditionModel,
         extra_conditioning_info: Optional[ExtraConditioningInfo],
         step_count: int,
     ):
-        old_attn_processors = None
-        if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
-            old_attn_processors = unet.attn_processors
-            # Load lora conditions into the model
-            if extra_conditioning_info.wants_cross_attention_control:
-                self.cross_attention_control_context = Context(
-                    arguments=extra_conditioning_info.cross_attention_control_args,
-                    step_count=step_count,
-                )
-                setup_cross_attention_control_attention_processors(
-                    unet,
-                    self.cross_attention_control_context,
-                )
+        old_attn_processors = unet.attn_processors
         try:
+            self.cross_attention_control_context = Context(
+                arguments=extra_conditioning_info.cross_attention_control_args,
+                step_count=step_count,
+            )
+            setup_cross_attention_control_attention_processors(
+                unet,
+                self.cross_attention_control_context,
+            )
+
             yield None
         finally:
             self.cross_attention_control_context = None
-            if old_attn_processors is not None:
-                unet.set_attn_processor(old_attn_processors)
+            unet.set_attn_processor(old_attn_processors)
 
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
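Note: with the `wants_cross_attention_control` guard moved to the caller (see the `diffusers_pipeline.py` hunk above), `custom_attention_context` now unconditionally snapshots and restores the UNet's processors, and it reads `extra_conditioning_info.cross_attention_control_args` without a None check, so it must only be entered when cross-attention control is actually requested. A small, self-contained sketch of the restore-on-error guarantee that the `try`/`finally` provides (the `patched_state` helper and its dict stand-in are illustrative):

```python
from contextlib import contextmanager


@contextmanager
def patched_state(procs: dict):
    """Dict stand-in for the UNet: snapshot `procs`, patch it, always restore."""
    saved = dict(procs)  # like `old_attn_processors = unet.attn_processors`
    try:
        procs["attn2.processor"] = "custom"  # like setup_cross_attention_control_attention_processors(...)
        yield
    finally:
        procs.clear()
        procs.update(saved)  # like `unet.set_attn_processor(old_attn_processors)`


procs = {"attn2.processor": "default"}
try:
    with patched_state(procs):
        raise RuntimeError("denoising failed mid-step")
except RuntimeError:
    pass
assert procs == {"attn2.processor": "default"}  # originals restored despite the error
```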