diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 8261b29dbb..36f91566d4 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -10,7 +10,6 @@ import torch
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import UNet2DConditionModel
 from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 4725aa98a3..2873c52322 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.ip_adapter.scales import Scales
 
 
 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -48,7 +47,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
             the weight scale of image prompt.
     """
 
-    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -125,9 +124,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
             assert ip_adapter_image_prompt_embeds is not None
             assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
 
-            for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
-            ):
+            for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The channel dimensions should match.
diff --git a/invokeai/backend/ip_adapter/scales.py b/invokeai/backend/ip_adapter/scales.py
deleted file mode 100644
index c4bf2b7a29..0000000000
--- a/invokeai/backend/ip_adapter/scales.py
+++ /dev/null
@@ -1,19 +0,0 @@
-class Scales:
-    """The IP-Adapter scales for a patched UNet. This object can be used to dynamically change the scales for a patched
-    UNet.
-    """
-
-    def __init__(self, scales: list[float]):
-        self._scales = scales
-
-    @property
-    def scales(self):
-        return self._scales
-
-    @scales.setter
-    def scales(self, scales: list[float]):
-        assert len(scales) == len(self._scales)
-        self._scales = scales
-
-    def __len__(self):
-        return len(self._scales)
diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/ip_adapter/unet_patcher.py
index 76d7dd0c7d..f8c1870f6e 100644
--- a/invokeai/backend/ip_adapter/unet_patcher.py
+++ b/invokeai/backend/ip_adapter/unet_patcher.py
@@ -4,46 +4,50 @@ from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.scales import Scales
 
 
-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
-    """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-    weights into them.
+class UNetPatcher:
+    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
 
-    Note that the `unet` param is only used to determine attention block dimensions and naming.
-    """
-    # Construct a dict of attention processors based on the UNet's architecture.
-    attn_procs = {}
-    for idx, name in enumerate(unet.attn_processors.keys()):
-        if name.endswith("attn1.processor"):
-            attn_procs[name] = AttnProcessor2_0()
-        else:
-            # Collect the weights from each IP Adapter for the idx'th attention processor.
-            attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
-            )
-    return attn_procs
+    def __init__(self, ip_adapters: list[IPAdapter]):
+        self._ip_adapters = ip_adapters
+        self._scales = [1.0] * len(self._ip_adapters)
 
+    def set_scale(self, idx: int, value: float):
+        self._scales[idx] = value
 
-@contextmanager
-def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors.
+    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
+        """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
+        weights into them.
 
-    Yields:
-        Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
-    """
-    scales = Scales([1.0] * len(ip_adapters))
+        Note that the `unet` param is only used to determine attention block dimensions and naming.
+        """
+        # Construct a dict of attention processors based on the UNet's architecture.
+        attn_procs = {}
+        for idx, name in enumerate(unet.attn_processors.keys()):
+            if name.endswith("attn1.processor"):
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                # Collect the weights from each IP Adapter for the idx'th attention processor.
+                attn_procs[name] = IPAttnProcessor2_0(
+                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
+                    self._scales,
+                )
+        return attn_procs
 
-    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)
+    @contextmanager
+    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
+        """A context manager that patches `unet` with IP-Adapter attention processors."""
 
-    orig_attn_processors = unet.attn_processors
+        attn_procs = self._prepare_attention_processors(unet)
 
-    try:
-        # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-        # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-        # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
-        unet.set_attn_processor(attn_procs)
-        yield scales
-    finally:
-        unet.set_attn_processor(orig_attn_processors)
+        orig_attn_processors = unet.attn_processors
+
+        try:
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
+            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
+            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            unet.set_attn_processor(attn_procs)
+            yield None
+        finally:
+            unet.set_attn_processor(orig_attn_processors)
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 3c695e3733..0943b78bf8 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -425,8 +425,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver
 
+        ip_adapter_unet_patcher = None
         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +436,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx_mgr = apply_ip_adapter_attention(
-                unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
-            )
+            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
+            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
         else:
-            attn_ctx_mgr = nullcontext()
+            attn_ctx = nullcontext()
 
-        with attn_ctx_mgr as attn_ctx:
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,7 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    attn_ctx=attn_ctx,
+                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
                 )
 
                 latents = step_output.prev_sample
@@ -515,7 +515,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        attn_ctx: Optional[Scales] = None,
+        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -538,10 +538,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    attn_ctx.scales[i] = weight
+                    ip_adapter_unet_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    attn_ctx.scales[i] = 0.0
+                    ip_adapter_unet_patcher.set_scale(i, 0.0)
 
         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index f2ca243a93..7f634ee1fe 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model
 
@@ -66,7 +66,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     unet.to(torch_device, dtype=torch.float32)
 
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
-    with apply_ip_adapter_attention(unet, [ip_adapter]):
+    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
        output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
 
    assert output.shape == dummy_unet_input["sample"].shape
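For reviewers, here is a minimal usage sketch of the new call pattern. This is not code from the diff: the `unet`, `ip_adapter_a`/`ip_adapter_b`, and `unet_inputs`/`cross_attention_kwargs` names are hypothetical placeholders, while `UNetPatcher(...)`, `set_scale(...)`, and `apply_ip_adapter_attention(...)` are taken directly from the new `invokeai/backend/ip_adapter/unet_patcher.py` above.

```python
# Minimal usage sketch of the new UNetPatcher API (assumed setup; see note above).
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher

# Hypothetical, previously-loaded objects: a diffusers UNet2DConditionModel
# (`unet`) and two InvokeAI IPAdapter instances.
patcher = UNetPatcher([ip_adapter_a, ip_adapter_b])

# Scales default to 1.0 per adapter and can now be changed per denoising step
# without threading a separate Scales object through the attention processors.
patcher.set_scale(0, 0.75)  # weight the first IP-Adapter at 75%
patcher.set_scale(1, 0.0)   # disable the second IP-Adapter entirely

with patcher.apply_ip_adapter_attention(unet):
    # Forward passes in this block run with IP-Adapter attention processors
    # installed; the original processors are restored when the context exits.
    output = unet(**unet_inputs, cross_attention_kwargs=cross_attention_kwargs).sample
```

Note the design shift: the context manager now yields `None`, and per-step scale updates go through the long-lived `UNetPatcher` (`ip_adapter_unet_patcher.set_scale(...)` in `diffusers_pipeline.py`) rather than through the `Scales` object that `apply_ip_adapter_attention` used to yield.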