mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackify chore
This commit is contained in:
parent
4f82273fc4
commit
b5cee7d20c
@ -6,7 +6,13 @@ from torch import nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlnetMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from diffusers.models.embeddings import (
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
@ -22,6 +28,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control
|
||||
# TODO: create PR to diffusers
|
||||
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
@ -736,9 +743,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
@ -749,6 +754,5 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
)
|
||||
|
||||
|
||||
|
||||
diffusers.ControlNetModel = ControlNetModel
|
||||
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
||||
|
Loading…
Reference in New Issue
Block a user