diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 4710682ac1..a02294f720 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -98,6 +98,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, @@ -135,6 +136,8 @@ class ControlNetModel(ModelMixin, ConfigMixin): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) # input conv_in_kernel = 3 @@ -212,6 +215,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): down_block = get_down_block( down_block_type, num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -248,6 +252,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps,