Update 'monkeypatched' controlnet class (#4269)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
Should be removed when added in diffusers
https://github.com/huggingface/diffusers/pull/4599
This commit is contained in:
Lincoln Stein 2023-08-17 15:49:54 -04:00 committed by GitHub
commit 832335998f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 6 deletions

View File

@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent:
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False, return_dict=False,
) )

View File

@ -4,8 +4,15 @@ import torch
from torch import nn from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import ( from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D, CrossAttnDownBlock2D,
@ -18,10 +25,11 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added # Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin): class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
""" """
A ControlNet model. A ControlNet model.
@ -52,12 +60,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The epsilon to use for the normalization. The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280): cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features. The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads. The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`): use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`): class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0): num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
@ -98,10 +119,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32, norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
cross_attention_dim: int = 1280, cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None, num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False, use_linear_projection: bool = False,
class_embed_type: Optional[str] = None, class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
@ -109,6 +135,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb", controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False, global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
): ):
super().__init__() super().__init__()
@ -136,6 +163,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
) )
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input # input
conv_in_kernel = 3 conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2 conv_in_padding = (conv_in_kernel - 1) // 2
@ -145,16 +175,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# time # time
time_embed_dim = block_out_channels[0] * 4 time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0] timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding( self.time_embedding = TimestepEmbedding(
timestep_input_dim, timestep_input_dim,
time_embed_dim, time_embed_dim,
act_fn=act_fn, act_fn=act_fn,
) )
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding # class embedding
if class_embed_type is None and num_class_embeds is not None: if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@ -178,6 +235,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
else: else:
self.class_embedding = None self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding # control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding( self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0], conditioning_embedding_channels=block_out_channels[0],
@ -212,6 +292,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block = get_down_block( down_block = get_down_block(
down_block_type, down_block_type,
num_layers=layers_per_block, num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel, in_channels=input_channel,
out_channels=output_channel, out_channels=output_channel,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
@ -248,6 +329,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
self.controlnet_mid_block = controlnet_block self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn( self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel, in_channels=mid_block_channel,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
resnet_eps=norm_eps, resnet_eps=norm_eps,
@ -277,7 +359,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable. where applicable.
""" """
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls( controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels, in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos, flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift, freq_shift=unet.config.freq_shift,
@ -463,6 +560,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False, guess_mode: bool = False,
@ -486,6 +584,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`. A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`): encoder_attention_mask (`torch.Tensor`):
@ -549,6 +649,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=sample.dtype) t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond) emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None: if self.class_embedding is not None:
if class_labels is None: if class_labels is None:
@ -560,11 +661,34 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb emb = emb + class_emb
if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond sample = sample + controlnet_cond
# 3. down # 3. down