mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Fix handling of scales with multiple IP-Adapters.

This commit is contained in:
parent 9403672ac0
commit d8d0c9af09
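In short: mutable per-adapter scale state moves off of each IPAttentionProcessorWeights module (which previously exposed set_scale()) and into a single Scales object that all patched attention processors share and that apply_ip_adapter_attention now yields. A minimal usage sketch of that design; `unet`, `ip_adapters`, and `run_denoising_loop` are illustrative placeholders, not code from this commit:

from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention

with apply_ip_adapter_attention(unet, ip_adapters) as scales:
    # `scales` is the shared Scales object; entry i controls the i'th IP-Adapter.
    scales.scales[0] = 0.75  # re-weight the first adapter mid-denoising
    run_denoising_loop()     # placeholder for the sampling loop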
invokeai/backend/ip_adapter/attention_processor.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.ip_adapter.scales import Scales


 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -47,13 +48,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
         the weight scale of image prompt.
     """

-    def __init__(self, weights: list[IPAttentionProcessorWeights]):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

-        self.weights = weights
+        assert len(weights) == len(scales)
+
+        self._weights = weights
+        self._scales = scales

     def __call__(
         self,
@@ -119,9 +123,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
             # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
             # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
             assert ip_adapter_image_prompt_embeds is not None
-            assert len(ip_adapter_image_prompt_embeds) == len(self.weights)
+            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)

-            for ipa_embed, ipa_weights in zip(ip_adapter_image_prompt_embeds, self.weights):
+            for ipa_embed, ipa_weights, scale in zip(
+                ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
+            ):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The channel dimensions should match.
@@ -144,7 +150,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
                 ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 ip_hidden_states = ip_hidden_states.to(query.dtype)

-                hidden_states = hidden_states + ipa_weights.scale * ip_hidden_states
+                hidden_states = hidden_states + scale * ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
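For intuition, a toy recomputation (an illustrative assumption, not repo code) of the blend performed by the loop in the hunks above: each adapter's attention output is added to the base hidden states, weighted by that adapter's entry in Scales:

import torch

hidden_states = torch.zeros(1, 4, 8)  # stand-in for the base attention output
ip_hidden_states_per_adapter = [torch.randn(1, 4, 8), torch.randn(1, 4, 8)]
per_adapter_scales = [1.0, 0.5]  # one entry per IP-Adapter

for ip_hidden_states, scale in zip(ip_hidden_states_per_adapter, per_adapter_scales):
    # Mirrors `hidden_states = hidden_states + scale * ip_hidden_states` in the diff.
    hidden_states = hidden_states + scale * ip_hidden_states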
invokeai/backend/ip_adapter/ip_attention_weights.py

@@ -8,9 +8,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
     method.
     """

-    def __init__(self, in_dim: int, out_dim: int, scale: float = 1.0):
+    def __init__(self, in_dim: int, out_dim: int):
         super().__init__()
-        self.scale = scale
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)

@@ -26,11 +25,6 @@ class IPAttentionWeights(torch.nn.Module):
         super().__init__()
         self._weights = weights

-    def set_scale(self, scale: float):
-        """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
-        for w in self._weights.values():
-            w.scale = scale
-
     def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
         """Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
         # Cast to int first, because we expect the key to represent an int. Then cast back to str, because
invokeai/backend/ip_adapter/scales.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+class Scales:
+    """The IP-Adapter scales for a patched UNet. This object can be used to dynamically change the scales for a patched
+    UNet.
+    """
+
+    def __init__(self, scales: list[float]):
+        self._scales = scales
+
+    @property
+    def scales(self):
+        return self._scales
+
+    @scales.setter
+    def scales(self, scales: list[float]):
+        assert len(scales) == len(self._scales)
+        self._scales = scales
+
+    def __len__(self):
+        return len(self._scales)
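A short usage sketch of Scales as defined above. Per-index updates go through the list returned by the getter, while whole-list assignment hits the length-checked setter:

scales = Scales([1.0, 1.0])   # one scale per IP-Adapter

scales.scales[0] = 0.5        # in-place update of a single adapter's scale
scales.scales = [0.25, 0.75]  # full replacement; the setter asserts the length matches

assert len(scales) == 2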
invokeai/backend/ip_adapter/unet_patcher.py

@@ -4,9 +4,10 @@ from diffusers.models import UNet2DConditionModel

 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.ip_adapter.scales import Scales


-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
     """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
     weights into them.

@@ -32,15 +33,22 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
         else:
             # Collect the weights from each IP Adapter for the idx'th attention processor.
             attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
+                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
             )
     return attn_procs


 @contextmanager
 def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors."""
-    attn_procs = _prepare_attention_processors(unet, ip_adapters)
+    """A context manager that patches `unet` with IP-Adapter attention processors.
+
+    Yields:
+        Scales: The Scales object, which can be used to dynamically alter the scales of the
+        IP-Adapters.
+    """
+    scales = Scales([1.0] * len(ip_adapters))
+
+    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)

     orig_attn_processors = unet.attn_processors

@@ -49,6 +57,6 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
         # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
         # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
         unet.set_attn_processor(attn_procs)
-        yield None
+        yield scales
     finally:
         unet.set_attn_processor(orig_attn_processors)
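The design hinge here is aliasing: _prepare_attention_processors hands the same Scales instance to every IPAttnProcessor2_0, so one mutation through the yielded object re-weights all patched attention layers at once. A self-contained sketch of that sharing (the 768/320 dims and the two-layer count are illustrative assumptions; requires PyTorch >= 2.0):

from invokeai.backend.ip_adapter.attention_processor import IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.ip_adapter.scales import Scales

shared_scales = Scales([1.0])  # a single IP-Adapter
procs = [
    IPAttnProcessor2_0([IPAttentionProcessorWeights(768, 320)], shared_scales)
    for _ in range(2)  # e.g. two cross-attention layers
]
shared_scales.scales[0] = 0.2  # visible to both processors on their next forward pass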
invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData

 from ..util import auto_detect_slice_size, normalize_device
@@ -426,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             return latents, attention_map_saver

         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +435,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx = apply_ip_adapter_attention(
+            attn_ctx_mgr = apply_ip_adapter_attention(
                 unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
             )
             self.use_ip_adapter = True
         else:
-            attn_ctx = nullcontext()
+            attn_ctx_mgr = nullcontext()

-        with attn_ctx:
+        with attn_ctx_mgr as attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
+                    attn_ctx=attn_ctx,
                 )
                 latents = step_output.prev_sample
@@ -514,6 +515,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
+        attn_ctx: Optional[Scales] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -526,7 +528,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         # handle IP-Adapter
         if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            for single_ip_adapter_data in ip_adapter_data:
+            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
                 first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
                 last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
                 weight = (
@@ -536,10 +538,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
                     attn_ctx.scales[i] = weight
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
                     attn_ctx.scales[i] = 0.0

         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
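To make the step gating above concrete, a worked example (all values invented for illustration) of how one adapter's scale is chosen at a single denoising step:

import math

total_step_count = 30
step_index = 10
begin_step_percent, end_step_percent = 0.2, 0.8
weight = 1.0

first_adapter_step = math.floor(begin_step_percent * total_step_count)  # 6
last_adapter_step = math.ceil(end_step_percent * total_step_count)      # 24
scale = weight if first_adapter_step <= step_index <= last_adapter_step else 0.0
assert scale == 1.0  # step 10 falls inside [6, 24], so the adapter is active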