diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 3d07e685c5..4725aa98a3 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.ip_adapter.scales import Scales
 
 
 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -47,13 +48,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
         the weight scale of image prompt.
     """
 
-    def __init__(self, weights: list[IPAttentionProcessorWeights]):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-        self.weights = weights
+        assert len(weights) == len(scales)
+
+        self._weights = weights
+        self._scales = scales
 
     def __call__(
         self,
@@ -119,9 +123,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
             # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
             # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
             assert ip_adapter_image_prompt_embeds is not None
-            assert len(ip_adapter_image_prompt_embeds) == len(self.weights)
+            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
 
-            for ipa_embed, ipa_weights in zip(ip_adapter_image_prompt_embeds, self.weights):
+            for ipa_embed, ipa_weights, scale in zip(
+                ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
+            ):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The channel dimensions should match.
@@ -144,7 +150,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
                 ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-                hidden_states = hidden_states + ipa_weights.scale * ip_hidden_states
+                hidden_states = hidden_states + scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
diff --git a/invokeai/backend/ip_adapter/ip_attention_weights.py b/invokeai/backend/ip_adapter/ip_attention_weights.py
index e7ed9e9c76..9c3b8969c6 100644
--- a/invokeai/backend/ip_adapter/ip_attention_weights.py
+++ b/invokeai/backend/ip_adapter/ip_attention_weights.py
@@ -8,9 +8,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
     method.
     """
 
-    def __init__(self, in_dim: int, out_dim: int, scale: float = 1.0):
+    def __init__(self, in_dim: int, out_dim: int):
         super().__init__()
-        self.scale = scale
 
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
@@ -26,11 +25,6 @@ class IPAttentionWeights(torch.nn.Module):
         super().__init__()
         self._weights = weights
 
-    def set_scale(self, scale: float):
-        """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
-        for w in self._weights.values():
-            w.scale = scale
-
     def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
         """Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
         # Cast to int first, because we expect the key to represent an int. Then cast back to str, because
diff --git a/invokeai/backend/ip_adapter/scales.py b/invokeai/backend/ip_adapter/scales.py
new file mode 100644
index 0000000000..c4bf2b7a29
--- /dev/null
+++ b/invokeai/backend/ip_adapter/scales.py
@@ -0,0 +1,19 @@
+class Scales:
+    """The IP-Adapter scales for a patched UNet. This object can be used to dynamically change the scales for a patched
+    UNet.
+    """
+
+    def __init__(self, scales: list[float]):
+        self._scales = scales
+
+    @property
+    def scales(self):
+        return self._scales
+
+    @scales.setter
+    def scales(self, scales: list[float]):
+        assert len(scales) == len(self._scales)
+        self._scales = scales
+
+    def __len__(self):
+        return len(self._scales)
diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/ip_adapter/unet_patcher.py
index ac9cf6de83..e45891d488 100644
--- a/invokeai/backend/ip_adapter/unet_patcher.py
+++ b/invokeai/backend/ip_adapter/unet_patcher.py
@@ -4,9 +4,10 @@ from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.ip_adapter.scales import Scales
 
 
-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
     """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
     weights into them.
 
@@ -32,15 +33,22 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
         else:
             # Collect the weights from each IP Adapter for the idx'th attention processor.
             attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
+                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
             )
 
     return attn_procs
 
 
 @contextmanager
 def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors."""
-    attn_procs = _prepare_attention_processors(unet, ip_adapters)
+    """A context manager that patches `unet` with IP-Adapter attention processors.
+
+    Yields:
+        Scales: The Scales object, which can be used to dynamically alter the scales of the
+            IP-Adapters.
+    """
+    scales = Scales([1.0] * len(ip_adapters))
+
+    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)
 
     orig_attn_processors = unet.attn_processors
@@ -49,6 +57,6 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
         # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
         # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
         unet.set_attn_processor(attn_procs)
-        yield None
+        yield scales
     finally:
         unet.set_attn_processor(orig_attn_processors)
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index a2bd6457a0..d7bd8bff17 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -426,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             return latents, attention_map_saver
 
         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +435,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx = apply_ip_adapter_attention(
+            attn_ctx_mgr = apply_ip_adapter_attention(
                 unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
             )
             self.use_ip_adapter = True
         else:
-            attn_ctx = nullcontext()
+            attn_ctx_mgr = nullcontext()
 
-        with attn_ctx:
+        with attn_ctx_mgr as attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
+                    attn_ctx=attn_ctx,
                 )
                 latents = step_output.prev_sample
 
@@ -514,6 +515,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
+        attn_ctx: Optional[Scales] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -526,7 +528,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # handle IP-Adapter
         if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            for single_ip_adapter_data in ip_adapter_data:
+            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
                 first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
                 last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
                 weight = (
                     single_ip_adapter_data.weight[step_index]
                     if isinstance(single_ip_adapter_data.weight, List)
                     else single_ip_adapter_data.weight
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(weight)
+                    attn_ctx.scales[i] = weight
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(0.0)
+                    attn_ctx.scales[i] = 0.0
 
         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
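
Usage sketch (not part of the patch): a minimal illustration of how a caller might drive the dynamic scales introduced above, assuming a loaded `unet` and a list of `IPAdapter` models already exist in the caller's scope; `per_step_weights` is a hypothetical per-step weight schedule used only for illustration.

from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention

# `unet` and `ip_adapters` are assumed to be loaded by the caller.
with apply_ip_adapter_attention(unet, ip_adapters) as attn_ctx:
    # attn_ctx is the Scales object yielded by the patcher: one scale per IP-Adapter,
    # initialized to 1.0. Mutating it between steps changes the conditioning strength
    # seen by every patched IPAttnProcessor2_0 on the next forward pass.
    for step_index, weights_for_step in enumerate(per_step_weights):
        for i, weight in enumerate(weights_for_step):
            attn_ctx.scales[i] = weight  # 0.0 disables the i'th IP-Adapter for this step
        # ... run the denoising step with the patched unet ...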