From 4f82273fc4ac8846791c7196deea672d23e2c9ef Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 14 Aug 2023 15:18:54 +0300 Subject: [PATCH 1/2] Update 'monkeypatched' controlnet class --- .../diffusion/shared_invokeai_diffusion.py | 1 + invokeai/backend/util/hotfixes.py | 136 ++++++++++++++++-- 2 files changed, 129 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index e739855b9e..f16855e775 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent: controlnet_cond=control_datum.image_tensor, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale encoder_attention_mask=encoder_attention_mask, + added_cond_kwargs=added_cond_kwargs, guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel return_dict=False, ) diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 4710682ac1..9c643d13bc 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -4,8 +4,9 @@ import torch from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalControlnetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -18,10 +19,10 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel import diffusers from 
diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module +# TODO: create PR to diffusers # Modified ControlNetModel with encoder_attention_mask argument added - -class ControlNetModel(ModelMixin, ConfigMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. @@ -52,12 +53,25 @@ class ControlNetModel(ModelMixin, ConfigMixin): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. 
num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. @@ -98,10 +112,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", @@ -109,6 +128,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, ): super().__init__() @@ -136,6 +156,9 @@ class ControlNetModel(ModelMixin, ConfigMixin): f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 @@ -145,16 +168,43 @@ class ControlNetModel(ModelMixin, ConfigMixin): # time time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1) + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
+ ) + else: + self.encoder_hid_proj = None + # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) @@ -178,6 +228,29 @@ class ControlNetModel(ModelMixin, ConfigMixin): else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], @@ -212,6 +285,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): down_block = get_down_block( down_block_type, num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -248,6 +322,7 @@ class ControlNetModel(ModelMixin, 
ConfigMixin): self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, @@ -277,7 +352,22 @@ class ControlNetModel(ModelMixin, ConfigMixin): The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, @@ -463,6 +553,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, guess_mode: bool = False, @@ -486,7 +577,9 @@ class ControlNetModel(ModelMixin, ConfigMixin): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If @@ -549,6 +642,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -560,11 +654,34 @@ class ControlNetModel(ModelMixin, ConfigMixin): class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = 
self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + # 2. pre-process sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + controlnet_cond # 3. down @@ -619,7 +736,9 @@ class ControlNetModel(ModelMixin, ConfigMixin): mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples] + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: @@ -630,5 +749,6 @@ class ControlNetModel(ModelMixin, ConfigMixin): ) + diffusers.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel From b5cee7d20ce6c45f18bedcd67e2f545359b5a3d0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 17 Aug 2023 15:40:15 -0400 Subject: [PATCH 2/2] blackify chore --- invokeai/backend/util/hotfixes.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 9c643d13bc..89b3da5a37 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -6,7 +6,13 @@ from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalControlnetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.embeddings import ( + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from diffusers.models.modeling_utils import ModelMixin from 
diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -22,6 +28,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control # TODO: create PR to diffusers # Modified ControlNetModel with encoder_attention_mask argument added + class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. @@ -736,9 +743,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] + down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: @@ -749,6 +754,5 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ) - diffusers.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel