blackify chore

This commit is contained in:
Lincoln Stein 2023-08-17 15:40:15 -04:00
parent 4f82273fc4
commit b5cee7d20c

View File

@ -6,7 +6,13 @@ from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
@ -22,6 +28,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.
@ -736,9 +743,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
@ -749,6 +754,5 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel