2023-07-16 15:08:38 +00:00
from typing import Any , Dict , List , Optional , Tuple , Union
import torch
from torch import nn
from diffusers . configuration_utils import ConfigMixin , register_to_config
2023-08-14 12:18:54 +00:00
from diffusers . loaders import FromOriginalControlnetMixin
2023-07-16 15:08:38 +00:00
from diffusers . models . attention_processor import AttentionProcessor , AttnProcessor
2023-08-17 19:40:15 +00:00
from diffusers . models . embeddings import (
TextImageProjection ,
TextImageTimeEmbedding ,
TextTimeEmbedding ,
TimestepEmbedding ,
Timesteps ,
)
2023-07-16 15:08:38 +00:00
from diffusers . models . modeling_utils import ModelMixin
from diffusers . models . unet_2d_blocks import (
CrossAttnDownBlock2D ,
DownBlock2D ,
UNetMidBlock2DCrossAttn ,
get_down_block ,
)
from diffusers . models . unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers . models . controlnet import ControlNetConditioningEmbedding , ControlNetOutput , zero_module
2023-08-17 22:45:25 +00:00
from invokeai . backend . util . logging import InvokeAILogger
2023-08-14 12:18:54 +00:00
# TODO: upstream these changes to diffusers (open a PR)
2023-07-16 15:08:38 +00:00
# Modified copy of diffusers' ControlNetModel that adds an `encoder_attention_mask` argument to `forward`
2023-08-17 19:40:15 +00:00
2023-08-17 22:45:25 +00:00
# Module-level logger obtained through InvokeAI's logging wrapper; ControlNetModel.__init__ uses it to
# report when `encoder_hid_dim_type` is defaulted.
logger = InvokeAILogger.getLogger(__name__)
2023-08-14 12:18:54 +00:00
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    A ControlNet model.

    Local copy of diffusers' `ControlNetModel` whose `forward` additionally accepts an `encoder_attention_mask`.
    The mask is converted to an additive bias and passed to the cross-attention down blocks and the mid block.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        conditioning_channels (`int`, defaults to 3):
            The number of channels of the conditioning input (`controlnet_cond`); passed to
            `ControlNetConditioningEmbedding`.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. Must have the same length as `block_out_channels`.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
            Forwarded to each down block. A single bool is broadcast to all down blocks; a tuple must have one
            entry per down block.
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The output scale factor of the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for the normalization layers of the down and mid blocks (passed to them as
            `resnet_groups`).
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`] per down block.
            An int is broadcast to all down blocks; the last entry is also used for the mid block. Only relevant
            for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`] and
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`. If given without `encoder_hid_dim_type`, the type defaults to
            `"text_proj"`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. One of `None`,
            `"text_proj"` or `"text_image_proj"`.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads. Also used as `num_attention_heads` when that argument is not
            given (a naming quirk inherited from diffusers).
        num_attention_heads (`Union[int, Tuple[int]]`, *optional*, defaults to `None`):
            The number of attention heads per down block. Falls back to `attention_head_dim` when `None`.
        use_linear_projection (`bool`, defaults to `False`):
            Forwarded to the down blocks and the mid block.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, `"identity"` or `"projection"`. Any other value, or `None` without
            `num_class_embeds`, leaves the model without a class embedding.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from:
            - `None`
            - `"text"`: uses the `TextTimeEmbedding` layer on `encoder_hidden_states`.
            - `"text_image"`: builds a `TextImageTimeEmbedding`; note that `forward` does not currently apply it.
            - `"text_time"`: projects `text_embeds` concatenated with sinusoidally embedded `time_ids`, both taken
              from `added_cond_kwargs`.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            The number of sinusoidal channels used to embed each `time_ids` entry when
            `addition_embed_type="text_time"`.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
            Forwarded to the down blocks and the mid block.
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`. Also used as the input dimension of the additional embedding when
            `addition_embed_type="text_time"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of the conditioning image, either `"rgb"` or `"bgr"`. `"bgr"` input is flipped
            along dim 1 in `forward`; any other value makes `forward` raise a `ValueError`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
            If `True`, `forward` mean-pools every output residual over dims 2 and 3 (with `keepdim=True`) and
            disables the `guess_mode` logarithmic scaling.
        addition_embed_type_num_heads (`int`, defaults to 64):
            The number of attention heads of the `TextTimeEmbedding` used when `addition_embed_type="text"`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        # NOTE(review): `encoder_hid_proj` is built here but `forward` never applies it -- confirm intended.
        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # additional embedding (summed with the time embedding in `forward`)
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type is not None:
            # The error message lists every value handled above, including 'text_time'.
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
            )

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        # Broadcast scalar per-block settings to one entry per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        # One zero-initialised 1x1 conv for the conv_in output, then one per residual produced by each down block.
        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                # Extra residual emitted by the downsampler.
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
                where applicable.
            controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
                Channel order of the conditioning image, stored in the new model's config.
            conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
                Output channels of each block of the conditioning embedding.
            load_weights_from_unet (`bool`, defaults to `True`):
                If `True`, copy the following weights from `unet`:
                - `conv_in`, `time_proj` and `time_embedding`;
                - the class embedding and the additional embedding, when present;
                - `down_blocks` and `mid_block`.

        Returns:
            [`ControlNetModel`]: the newly created model.
        """
        # Keys introduced in newer UNet configs; fall back to the ControlNet defaults for older configs.
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )
        addition_embed_type_num_heads = (
            unet.config.addition_embed_type_num_heads if "addition_embed_type_num_heads" in unet.config else 64
        )

        controlnet = cls(
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding is not None:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            # The additional embedding (e.g. SDXL "text_time") is only created when `addition_embed_type` is set.
            # Without this copy it would stay randomly initialised.
            controlnet_add_embedding = getattr(controlnet, "add_embedding", None)
            unet_add_embedding = getattr(unet, "add_embedding", None)
            if controlnet_add_embedding is not None and unet_add_embedding is not None:
                controlnet_add_embedding.load_state_dict(unet_add_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name (`"<module path>.processor"`).
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. Entries are
                popped from the passed dict as they are applied.

        Raises:
            ValueError: if a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Raises:
            ValueError: if a list of the wrong length is given, or a slice size exceeds its layer's head dim.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Reversed so that `pop()` hands out sizes in the same traversal order they were collected in.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the down blocks support gradient checkpointing in this model.
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The denoising timestep. A Python number or a 0-d tensor is wrapped into a 1-element tensor. The
                timestep is then broadcast to the batch size of `sample`.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.FloatTensor`):
                The conditioning image tensor. Dim 1 is treated as the channel axis (it is the axis flipped for
                `"bgr"` input), presumably `(batch_size, conditioning_channels, height, width)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
                Required when the model has a class embedding.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Extra condition passed to the time embedding layer together with the projected timesteps.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                Attention mask where 1 keeps and 0 discards a position. It is converted to an additive bias
                (`(1 - mask) * -10000`), gets a new dim 1, and is passed to the down and mid blocks.
            added_cond_kwargs (`dict`, *optional*):
                Additional conditions for the Stable Diffusion XL UNet. Must contain `text_embeds` and `time_ids`
                when `addition_embed_type == "text_time"`.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. The outputs are then
                scaled logarithmically from 0.1 (first down residual) to 1.0 (mid residual), times
                `conditioning_scale`; ignored when `global_pool_conditions` is set.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned. Otherwise a tuple
                `(down_block_res_samples, mid_block_res_sample)` is returned, where the first element is a list of
                the scaled down-block residuals and the second the scaled mid-block residual.

        Raises:
            ValueError:
                - the configured channel order is unknown;
                - `class_labels` is missing while the model has a class embedding;
                - a required `added_cond_kwargs` entry is missing.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if "addition_embed_type" in self.config:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                # Treat a missing dict as empty so the checks below raise a clear ValueError instead of a
                # TypeError from `in None`.
                if added_cond_kwargs is None:
                    added_cond_kwargs = {}
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                # Sinusoidally embed every time id, then regroup per batch item before concatenating with text.
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
2023-07-27 14:54:01 +00:00
2023-07-16 15:08:38 +00:00
# Install the patched ControlNetModel in place of diffusers' own class, both at the package top level and in
# `diffusers.models.controlnet`. Only lookups made after this module is imported see the patched class; names
# already bound via `from diffusers import ControlNetModel` keep the original.
setattr(diffusers, "ControlNetModel", ControlNetModel)
setattr(diffusers.models.controlnet, "ControlNetModel", ControlNetModel)