Update 'monkeypatched' controlnet class

Sergey Borisov 2023-08-14 15:18:54 +03:00 committed by Kent Keirsey
parent e54355f0f3
commit 4f82273fc4
2 changed files with 129 additions and 8 deletions

File 1 of 2

@@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent:
                     controlnet_cond=control_datum.image_tensor,
                     conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
                     encoder_attention_mask=encoder_attention_mask,
+                    added_cond_kwargs=added_cond_kwargs,
                     guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                     return_dict=False,
                 )
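For context, `added_cond_kwargs` is the dict of extra conditioning tensors that SDXL-style models expect. A minimal sketch of what a caller might assemble upstream of this call; the variable names and sizes are illustrative assumptions, not part of the commit:

```python
import torch

# Hypothetical SDXL-style conditioning inputs (batch of 2).
pooled_prompt_embeds = torch.randn(2, 1280)  # pooled text embeddings, (batch, 1280)

# SDXL packs six scalars per sample into time_ids:
# original size (h, w) + crop top-left (y, x) + target size (h, w).
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2, dtype=torch.float32)

added_cond_kwargs = {
    "text_embeds": pooled_prompt_embeds,  # consumed when addition_embed_type == "text_time"
    "time_ids": add_time_ids,
}
# This dict is what the call above now forwards to the patched ControlNetModel.
```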

File 2 of 2

@@ -4,8 +4,9 @@ import torch
 from torch import nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalControlnetMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -18,10 +19,10 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
 import diffusers
 from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module

+# TODO: create PR to diffusers
 # Modified ControlNetModel with encoder_attention_mask argument added
-class ControlNetModel(ModelMixin, ConfigMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
     """
     A ControlNet model.
@@ -52,12 +53,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             The epsilon to use for the normalization.
         cross_attention_dim (`int`, defaults to 1280):
             The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
         attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
             The dimension of the attention heads.
         use_linear_projection (`bool`, defaults to `False`):
         class_embed_type (`str`, *optional*, defaults to `None`):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to 0):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
@@ -98,10 +112,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -109,6 +128,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         controlnet_conditioning_channel_order: str = "rgb",
         conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
         global_pool_conditions: bool = False,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
@@ -136,6 +156,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
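The `isinstance` check above normalizes a scalar config value into a per-down-block list. A tiny self-contained illustration, with made-up values:

```python
down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")

transformer_layers_per_block = 2  # scalar form
if isinstance(transformer_layers_per_block, int):
    transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

assert transformer_layers_per_block == [2, 2, 2, 2]  # one entry per down block
```

SDXL configs instead pass a per-block tuple such as `(1, 2, 10)`, which is why the down blocks later index into it with `transformer_layers_per_block[i]`.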
@@ -145,16 +168,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         # time
         time_embed_dim = block_out_channels[0] * 4
         self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
         timestep_input_dim = block_out_channels[0]
         self.time_embedding = TimestepEmbedding(
             timestep_input_dim,
             time_embed_dim,
             act_fn=act_fn,
         )

+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
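The `"text_proj"` branch amounts to a single linear projection of the text-encoder output down to the cross-attention width. A minimal sketch with assumed sizes (note that the `logger.info(...)` call above relies on a module-level `logger`, presumably defined elsewhere in this file):

```python
import torch
from torch import nn

encoder_hid_dim, cross_attention_dim = 4096, 1280  # illustrative sizes only
encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)

encoder_hidden_states = torch.randn(2, 77, encoder_hid_dim)  # (batch, seq, dim)
projected = encoder_hid_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([2, 77, 1280])
```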
@@ -178,6 +228,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         else:
             self.class_embedding = None

+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
         # control net conditioning embedding
         self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
             conditioning_embedding_channels=block_out_channels[0],
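For SDXL checkpoints the relevant branch is `"text_time"`. A back-of-envelope dimension check using SDXL's published UNet config values (these numbers are not in this diff): six `time_ids` per sample, each expanded to a 256-wide sinusoidal embedding, concatenated with the 1280-wide pooled text embedding:

```python
# Dimension check for addition_embed_type == "text_time" (assumed SDXL values).
addition_time_embed_dim = 256  # sinusoidal width per time_id
num_time_ids = 6               # original size (2) + crop (2) + target size (2)
pooled_text_dim = 1280         # pooled prompt embedding width

projection_class_embeddings_input_dim = pooled_text_dim + num_time_ids * addition_time_embed_dim
assert projection_class_embeddings_input_dim == 2816  # matches SDXL's UNet config
```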
@@ -212,6 +285,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
@@ -248,6 +322,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         self.controlnet_mid_block = controlnet_block

         self.mid_block = UNetMidBlock2DCrossAttn(
+            transformer_layers_per_block=transformer_layers_per_block[-1],
             in_channels=mid_block_channel,
             temb_channels=time_embed_dim,
             resnet_eps=norm_eps,
@@ -277,7 +352,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
                 where applicable.
         """
+        transformer_layers_per_block = (
+            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+        )
+        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+        addition_time_embed_dim = (
+            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+        )
+
         controlnet = cls(
+            encoder_hid_dim=encoder_hid_dim,
+            encoder_hid_dim_type=encoder_hid_dim_type,
+            addition_embed_type=addition_embed_type,
+            addition_time_embed_dim=addition_time_embed_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=unet.config.in_channels,
             flip_sin_to_cos=unet.config.flip_sin_to_cos,
             freq_shift=unet.config.freq_shift,
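With those keys now copied from the UNet config, `from_unet` can initialize a ControlNet from an SDXL UNet. A hedged usage sketch, assuming it runs where the patched class is in scope (model id illustrative):

```python
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# `ControlNetModel` here is assumed to be the patched class defined in this file.
controlnet = ControlNetModel.from_unet(pipe.unet)
print(controlnet.config.addition_embed_type)  # "text_time" for SDXL
```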
@@ -463,6 +553,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         guess_mode: bool = False,
@@ -486,7 +577,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
             timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
             attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
-            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
+            added_cond_kwargs (`dict`):
+                Additional conditions for the Stable Diffusion XL UNet.
+            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
             encoder_attention_mask (`torch.Tensor`):
                 A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@@ -549,6 +642,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None

         if self.class_embedding is not None:
             if class_labels is None:
@@ -560,11 +654,34 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
             emb = emb + class_emb

+        if "addition_embed_type" in self.config:
+            if self.config.addition_embed_type == "text":
+                aug_emb = self.add_embedding(encoder_hidden_states)
+            elif self.config.addition_embed_type == "text_time":
+                if "text_embeds" not in added_cond_kwargs:
+                    raise ValueError(
+                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                    )
+                text_embeds = added_cond_kwargs.get("text_embeds")
+                if "time_ids" not in added_cond_kwargs:
+                    raise ValueError(
+                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                    )
+                time_ids = added_cond_kwargs.get("time_ids")
+                time_embeds = self.add_time_proj(time_ids.flatten())
+                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+                add_embeds = add_embeds.to(emb.dtype)
+                aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
         # 2. pre-process
         sample = self.conv_in(sample)

         controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
         sample = sample + controlnet_cond

         # 3. down
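To make the `reshape` in the `"text_time"` branch concrete, here is a standalone shape walk-through under assumed SDXL dimensions (batch of 2):

```python
import torch
from diffusers.models.embeddings import Timesteps

batch, pooled_dim, time_dim = 2, 1280, 256
add_time_proj = Timesteps(time_dim, flip_sin_to_cos=True, downscale_freq_shift=0)

text_embeds = torch.randn(batch, pooled_dim)  # pooled prompt embeddings
time_ids = torch.zeros(batch, 6)              # six conditioning scalars per sample

time_embeds = add_time_proj(time_ids.flatten())  # (batch * 6, 256)
time_embeds = time_embeds.reshape((batch, -1))   # (batch, 1536)

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
print(add_embeds.shape)  # torch.Size([2, 2816])
```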
@@ -619,7 +736,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             mid_block_res_sample = mid_block_res_sample * conditioning_scale

         if self.config.global_pool_conditions:
-            down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
+            down_block_res_samples = [
+                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+            ]
             mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

         if not return_dict:
@@ -630,5 +749,6 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         )

 diffusers.ControlNetModel = ControlNetModel
 diffusers.models.controlnet.ControlNetModel = ControlNetModel
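Because the patch works by rebinding module attributes, it only affects code that resolves `ControlNetModel` through the `diffusers` namespace after this module has been imported. A sketch of the intended import order; the hotfix module path and model id are assumptions for illustration:

```python
# Import the patch module first so the two rebindings above have already run.
import invokeai.backend.util.hotfixes  # noqa: F401  (assumed module path)

import diffusers
from diffusers import ControlNetModel

assert diffusers.ControlNetModel is ControlNetModel  # both resolve to the patched class
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
```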