Refactor multi-IP-Adapter to clean up the interface around changing scales.

Ryan Dick 2023-10-06 18:13:35 -04:00 committed by Kent Keirsey
parent 43a3c3c7ea
commit 971ccfb081
6 changed files with 54 additions and 72 deletions
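For orientation, the caller-facing change can be sketched as follows (a minimal sketch; `unet` and `ip_adapters` are placeholders for a loaded UNet2DConditionModel and a list of IPAdapter instances):

    from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher

    # Before this commit: the context manager yielded a Scales object whose
    # list was mutated in place, e.g. `scales.scales[0] = 0.5`.
    #
    # After this commit: the scales live on the patcher and are changed
    # through an explicit method while the patch is active.
    patcher = UNetPatcher(ip_adapters)
    with patcher.apply_ip_adapter_attention(unet):
        patcher.set_scale(0, 0.5)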

invokeai/app/invocations/latent.py

@@ -10,7 +10,6 @@ import torch
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import UNet2DConditionModel
 from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,

invokeai/backend/ip_adapter/attention_processor.py

@@ -9,7 +9,6 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.ip_adapter.scales import Scales
 
 
 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -48,7 +47,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
         the weight scale of image prompt.
     """
 
-    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -125,9 +124,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
            assert ip_adapter_image_prompt_embeds is not None
            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
 
-            for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
-            ):
+            for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
                # The batch dimensions should match.
                assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                # The channel dimensions should match.
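Passing a bare list[float] instead of a Scales object still supports dynamic updates, because the processor stores a reference to the same list the patcher owns; in-place writes made later are visible at attention time. A minimal sketch of that mechanism (illustrative names, not codebase APIs):

    shared_scales = [1.0, 1.0]

    class Proc:
        def __init__(self, scales: list[float]):
            self._scales = scales  # Keeps a reference, not a copy.

    proc = Proc(shared_scales)
    shared_scales[0] = 0.0         # What UNetPatcher.set_scale(0, 0.0) does, in effect.
    assert proc._scales[0] == 0.0  # The processor observes the updated scale.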

invokeai/backend/ip_adapter/scales.py (deleted)

@@ -1,19 +0,0 @@
-class Scales:
-    """The IP-Adapter scales for a patched UNet. This object can be used to dynamically change the scales for a patched
-    UNet.
-    """
-
-    def __init__(self, scales: list[float]):
-        self._scales = scales
-
-    @property
-    def scales(self):
-        return self._scales
-
-    @scales.setter
-    def scales(self, scales: list[float]):
-        assert len(scales) == len(self._scales)
-        self._scales = scales
-
-    def __len__(self):
-        return len(self._scales)
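With the wrapper deleted, correct sharing now depends on mutating the list in place: rebinding to a new list would silently detach the attention processors, which is presumably why UNetPatcher.set_scale assigns by index. A plain-Python illustration:

    scales = [1.0, 1.0]
    alias = scales       # e.g. the reference held by an IPAttnProcessor2_0
    scales[0] = 0.5      # In-place write: visible through alias.
    scales = [0.0, 0.0]  # Rebinding: alias still points at the old list.
    assert alias == [0.5, 1.0]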

invokeai/backend/ip_adapter/unet_patcher.py

@@ -4,46 +4,50 @@ from diffusers.models import UNet2DConditionModel
 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.scales import Scales
 
 
-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
-    """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-    weights into them.
-
-    Note that the `unet` param is only used to determine attention block dimensions and naming.
-    """
-    # Construct a dict of attention processors based on the UNet's architecture.
-    attn_procs = {}
-    for idx, name in enumerate(unet.attn_processors.keys()):
-        if name.endswith("attn1.processor"):
-            attn_procs[name] = AttnProcessor2_0()
-        else:
-            # Collect the weights from each IP Adapter for the idx'th attention processor.
-            attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
-            )
-    return attn_procs
-
-
-@contextmanager
-def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors.
-
-    Yields:
-        Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
-    """
-    scales = Scales([1.0] * len(ip_adapters))
-
-    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)
-
-    orig_attn_processors = unet.attn_processors
-
-    try:
-        # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-        # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-        # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
-        unet.set_attn_processor(attn_procs)
-        yield scales
-    finally:
-        unet.set_attn_processor(orig_attn_processors)
+class UNetPatcher:
+    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+
+    def __init__(self, ip_adapters: list[IPAdapter]):
+        self._ip_adapters = ip_adapters
+        self._scales = [1.0] * len(self._ip_adapters)
+
+    def set_scale(self, idx: int, value: float):
+        self._scales[idx] = value
+
+    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
+        """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
+        weights into them.
+
+        Note that the `unet` param is only used to determine attention block dimensions and naming.
+        """
+        # Construct a dict of attention processors based on the UNet's architecture.
+        attn_procs = {}
+        for idx, name in enumerate(unet.attn_processors.keys()):
+            if name.endswith("attn1.processor"):
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                # Collect the weights from each IP Adapter for the idx'th attention processor.
+                attn_procs[name] = IPAttnProcessor2_0(
+                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
+                    self._scales,
+                )
+        return attn_procs
+
+    @contextmanager
+    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
+        """A context manager that patches `unet` with IP-Adapter attention processors."""
+        attn_procs = self._prepare_attention_processors(unet)
+        orig_attn_processors = unet.attn_processors
+        try:
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
+            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
+            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            unet.set_attn_processor(attn_procs)
+            yield None
+        finally:
+            unet.set_attn_processor(orig_attn_processors)
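A hedged usage sketch of the new class (assumes a loaded `unet` and `ip_adapter`; `sample`, `timestep`, `cond`, and `embeds` are placeholder tensors):

    patcher = UNetPatcher([ip_adapter])
    with patcher.apply_ip_adapter_attention(unet):
        # Scales can be adjusted while the patch is active, e.g. per denoising step.
        patcher.set_scale(0, 0.75)
        noise_pred = unet(
            sample,
            timestep,
            encoder_hidden_states=cond,
            cross_attention_kwargs={"ip_adapter_image_prompt_embeds": [embeds]},
        ).sample
    # On exit, the UNet's original attention processors are restored.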

invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -425,8 +425,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver
 
+        ip_adapter_unet_patcher = None
         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +436,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx_mgr = apply_ip_adapter_attention(
-                unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
-            )
+            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
+            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
         else:
-            attn_ctx_mgr = nullcontext()
+            attn_ctx = nullcontext()
 
-        with attn_ctx_mgr as attn_ctx:
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,7 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    attn_ctx=attn_ctx,
+                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
                 )
 
                 latents = step_output.prev_sample
@@ -515,7 +515,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        attn_ctx: Optional[Scales] = None,
+        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -538,10 +538,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             )
             if step_index >= first_adapter_step and step_index <= last_adapter_step:
                 # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                attn_ctx.scales[i] = weight
+                ip_adapter_unet_patcher.set_scale(i, weight)
             else:
                 # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                attn_ctx.scales[i] = 0.0
+                ip_adapter_unet_patcher.set_scale(i, 0.0)
 
             # Handle ControlNet(s) and T2I-Adapter(s)
             down_block_additional_residuals = None
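The begin/end windowing above amounts to the following standalone helper (hypothetical; `begin_step_percent`, `end_step_percent`, and `weight` are assumed per-adapter fields, mirroring how `first_adapter_step`, `last_adapter_step`, and `weight` are used in the hunk):

    import math

    def update_ip_adapter_scales(patcher, ip_adapter_data, step_index, total_steps):
        for i, ipa in enumerate(ip_adapter_data):
            first_step = math.floor(ipa.begin_step_percent * total_steps)
            last_step = math.ceil(ipa.end_step_percent * total_steps)
            if first_step <= step_index <= last_step:
                # Within this adapter's step window: apply its configured weight.
                patcher.set_scale(i, ipa.weight)
            else:
                # Outside the window: zero the scale so the adapter has no effect.
                patcher.set_scale(i, 0.0)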

tests/backend/ip_adapter/test_ip_adapter.py

@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model
 
@@ -66,7 +66,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     unet.to(torch_device, dtype=torch.float32)
 
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
-    with apply_ip_adapter_attention(unet, [ip_adapter]):
+    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
 
     assert output.shape == dummy_unet_input["sample"].shape
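A further assertion the test could make (not part of this diff) is that the context manager cleanly restores the original attention processors on exit; a sketch under the same fixtures:

    from invokeai.backend.ip_adapter.attention_processor import IPAttnProcessor2_0

    orig_procs = unet.attn_processors  # The property builds a fresh dict on each access.
    with UNetPatcher([ip_adapter]).apply_ip_adapter_attention(unet):
        assert any(isinstance(p, IPAttnProcessor2_0) for p in unet.attn_processors.values())
    assert unet.attn_processors.keys() == orig_procs.keys()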