Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Pass IP-Adapter scales through the cross_attn_kwargs pathway, since they are the same for all attention layers. This change also helps to prepare for adding IP-Adapter region masks.
commit 2e27ed5f3d
parent babdc64b17
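
For context, a minimal standalone sketch of the idea described in the commit message. The names AdapterSchedule, build_cross_attention_kwargs, and the "ip_adapter_scales" key are hypothetical illustrations, not InvokeAI's API (the diff below does the equivalent work with IPAdapterData.scale_for_step(), RegionalIPData, and cross_attention_kwargs): the per-adapter scale is computed once per denoising step and travels down to the attention layers with the cross-attention kwargs, instead of being set on each attention processor.

# Hedged sketch only -- illustrative names, not the InvokeAI implementation.
import math
from dataclasses import dataclass
from typing import List, Union


@dataclass
class AdapterSchedule:
    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    def scale_for_step(self, step_index: int, total_steps: int) -> float:
        # Mirrors the scale_for_step() added to IPAdapterData in this commit:
        # zero outside the [begin, end] step range, the scheduled weight inside it.
        first_step = math.floor(self.begin_step_percent * total_steps)
        last_step = math.ceil(self.end_step_percent * total_steps)
        weight = self.weight[step_index] if isinstance(self.weight, list) else self.weight
        return weight if first_step <= step_index <= last_step else 0.0


def build_cross_attention_kwargs(adapters: List[AdapterSchedule], step_index: int, total_steps: int) -> dict:
    # The scales are the same for every attention layer, so they are computed once per
    # step and passed through the cross-attention kwargs rather than stored on each
    # attention processor.
    return {"ip_adapter_scales": [a.scale_for_step(step_index, total_steps) for a in adapters]}


if __name__ == "__main__":
    adapters = [AdapterSchedule(weight=0.8), AdapterSchedule(weight=[0.0, 0.5, 1.0], end_step_percent=0.5)]
    print(build_cross_attention_kwargs(adapters, step_index=1, total_steps=3))
    # -> {'ip_adapter_scales': [0.8, 0.5]}
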
@@ -56,6 +56,7 @@ from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_sea
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    IPAdapterConditioningInfo,
    IPAdapterData,
    Range,
    SDXLConditioningInfo,
    TextConditioningData,
@@ -66,7 +67,6 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings

from ...backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
    IPAdapterData,
    StableDiffusionGeneratorPipeline,
    T2IAdapterData,
    image_resized_to_grid_as_tensor,
@@ -21,9 +21,8 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    IPAdapterConditioningInfo,
    IPAdapterData,
    TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -152,17 +151,6 @@ class ControlNetData:
    resize_mode: str = Field(default="just_resize")


@dataclass
class IPAdapterData:
    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo

    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = Field(default=1.0)
    begin_step_percent: float = Field(default=0.0)
    end_step_percent: float = Field(default=1.0)


@dataclass
class T2IAdapterData:
    """A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
@@ -439,7 +427,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                control_data=control_data,
                ip_adapter_data=ip_adapter_data,
                t2i_adapter_data=t2i_adapter_data,
                unet_attention_patcher=unet_attention_patcher,
            )
            latents = step_output.prev_sample
            predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -471,7 +458,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        control_data: List[ControlNetData] = None,
        ip_adapter_data: Optional[list[IPAdapterData]] = None,
        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]
@@ -486,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # handle IP-Adapter
        if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
                first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
                last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
                weight = (
                    single_ip_adapter_data.weight[step_index]
                    if isinstance(single_ip_adapter_data.weight, List)
                    else single_ip_adapter_data.weight
                )
                if step_index >= first_adapter_step and step_index <= last_adapter_step:
                    # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
                    unet_attention_patcher.set_scale(i, weight)
                else:
                    # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
                    unet_attention_patcher.set_scale(i, 0.0)

        # Handle ControlNet(s)
        down_block_additional_residuals = None
        mid_block_additional_residual = None
@@ -545,17 +514,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

            down_intrablock_additional_residuals = accum_adapter_state

        ip_adapter_conditioning = None
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]

        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
            timestep=t,  # TODO: debug how handled batched and non batched timesteps
            step_index=step_index,
            total_step_count=total_step_count,
            conditioning_data=conditioning_data,
            ip_adapter_conditioning=ip_adapter_conditioning,
            ip_adapter_data=ip_adapter_data,
            down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
            mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
@@ -1,8 +1,11 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter


@dataclass
class BasicConditioningInfo:
@@ -45,6 +48,27 @@ class IPAdapterConditioningInfo:
    """


@dataclass
class IPAdapterData:
    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo

    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    def scale_for_step(self, step_index: int, total_steps: int) -> float:
        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
        last_adapter_step = math.ceil(self.end_step_percent * total_steps)
        weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
        if step_index >= first_adapter_step and step_index <= last_adapter_step:
            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
            return weight
        # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
        return 0.0


@dataclass
class Range:
    start: int
@@ -5,6 +5,7 @@ import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@@ -54,15 +55,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
        regional_prompt_data: Optional[RegionalPromptData] = None,
        percent_through: Optional[torch.FloatTensor] = None,
        # For IP-Adapter:
        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
        regional_ip_data: Optional[RegionalIPData] = None,
    ) -> torch.FloatTensor:
        """Apply attention.
        Args:
            regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
                apply regional prompt masking.
            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
            regional_ip_data: The IP-Adapter data for the current batch.
        """
        # If true, we are doing cross-attention, if false we are doing self-attention.
        is_cross_attention = encoder_hidden_states is not None
@@ -141,9 +140,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
        # Apply IP-Adapter conditioning.
        if is_cross_attention and self._is_ip_adapter_enabled():
        if self._is_ip_adapter_enabled():
            assert ip_adapter_image_prompt_embeds is not None
            assert regional_ip_data is not None
            for ipa_embed, ipa_weights, scale in zip(
                ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
                regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
            ):
                # The batch dimensions should match.
                assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
@@ -178,8 +177,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

                hidden_states = hidden_states + scale * ip_hidden_states
        else:
            # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
            assert ip_adapter_image_prompt_embeds is None
            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
            assert regional_ip_data is None

        # Start unmodified block from AttnProcessor2_0.
        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
@@ -0,0 +1,20 @@
import torch


class RegionalIPData:
    """A class to manage the data for regional IP-Adapter conditioning."""

    def __init__(
        self,
        image_prompt_embeds: list[torch.Tensor],
        scales: list[float],
    ):
        """Initialize a `IPAdapterConditioningData` object."""
        # The image prompt embeddings.
        # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        self.image_prompt_embeds = image_prompt_embeds

        # The scales for the IP-Adapter attention.
        # scales[i] contains the attention scale for the i'th IP-Adapter.
        self.scales = scales
@@ -8,11 +8,12 @@ from typing_extensions import TypeAlias

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    IPAdapterConditioningInfo,
    IPAdapterData,
    Range,
    TextConditioningData,
    TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

ModelForwardCallback: TypeAlias = Union[
@@ -169,15 +170,13 @@ class InvokeAIDiffuserComponent:
        sample: torch.Tensor,
        timestep: torch.Tensor,
        conditioning_data: TextConditioningData,
        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
        percent_through = step_index / total_step_count

        if self.sequential_guidance:
            (
                unconditioned_next_x,
@@ -186,8 +185,9 @@ class InvokeAIDiffuserComponent:
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
                ip_adapter_conditioning=ip_adapter_conditioning,
                percent_through=percent_through,
                ip_adapter_data=ip_adapter_data,
                step_index=step_index,
                total_step_count=total_step_count,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -200,8 +200,9 @@ class InvokeAIDiffuserComponent:
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
                ip_adapter_conditioning=ip_adapter_conditioning,
                percent_through=percent_through,
                ip_adapter_data=ip_adapter_data,
                step_index=step_index,
                total_step_count=total_step_count,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -260,15 +261,16 @@ class InvokeAIDiffuserComponent:

    def _apply_standard_conditioning(
        self,
        x,
        sigma,
        x: torch.Tensor,
        sigma: torch.Tensor,
        conditioning_data: TextConditioningData,
        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        percent_through: float,
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
        the cost of higher memory usage.
        """
@@ -276,12 +278,16 @@ class InvokeAIDiffuserComponent:
        sigma_twice = torch.cat([sigma] * 2)

        cross_attention_kwargs = {}
        if ip_adapter_conditioning is not None:
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
            image_prompt_embeds = [
                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
                for ipa_conditioning in ip_adapter_conditioning
            ]
            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
@@ -326,7 +332,7 @@ class InvokeAIDiffuserComponent:
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=regions, device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = percent_through
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
@@ -350,8 +356,9 @@ class InvokeAIDiffuserComponent:
        x: torch.Tensor,
        sigma,
        conditioning_data: TextConditioningData,
        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        percent_through: float,
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -388,13 +395,18 @@ class InvokeAIDiffuserComponent:
        cross_attention_kwargs = {}

        # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
        if ip_adapter_conditioning is not None:
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
            image_prompt_embeds = [
                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
                for ipa_conditioning in ip_adapter_conditioning
            ]

            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        # Prepare SDXL conditioning kwargs for the unconditioned pass.
        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
@@ -408,7 +420,7 @@ class InvokeAIDiffuserComponent:
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = percent_through
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
@@ -428,14 +440,18 @@ class InvokeAIDiffuserComponent:

        cross_attention_kwargs = {}

        # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
        if ip_adapter_conditioning is not None:
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
            image_prompt_embeds = [
                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
                for ipa_conditioning in ip_adapter_conditioning
            ]

            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        # Prepare SDXL conditioning kwargs for the conditioned pass.
        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
@@ -449,7 +465,7 @@ class InvokeAIDiffuserComponent:
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = percent_through
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
@@ -17,9 +17,6 @@ class UNetAttentionPatcher:
        if self._ip_adapters is not None:
            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)

    def set_scale(self, idx: int, value: float):
        self._ip_adapter_scales[idx] = value

    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
        """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
        weights into them (if IP-Adapters are being applied).