Core implementation of ControlNet and MultiControlNet.

user1 2023-04-29 00:43:21 -07:00 committed by Kent Keirsey
parent 5569f205ee
commit 5ff98a4179
2 changed files with 115 additions and 13 deletions

View File

@@ -9,16 +9,20 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
)
@@ -27,6 +31,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
@@ -302,6 +307,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = "float32",
        control_model: ControlNetModel = None,
    ):
        super().__init__(
            vae,
@@ -322,6 +328,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            # FIXME: can't currently register control module
            # control_model=control_model,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            self.unet, self._unet_forward, is_running_diffusers=True
@@ -341,6 +349,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        self._model_group = FullyLoadedModelGroup(self.unet.device)
        self._model_group.install(*self._submodels)
        self.control_model = control_model

    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
@@ -463,6 +472,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        **kwargs,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.
@@ -483,6 +493,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            noise=noise,
            run_id=run_id,
            callback=callback,
            **kwargs,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()
@@ -507,6 +518,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        additional_guidance: List[Callable] = None,
        run_id=None,
        callback: Callable[[PipelineIntermediateState], None] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if self.scheduler.config.get("cpu_only", False):
            scheduler_device = torch.device('cpu')
@@ -527,6 +539,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            additional_guidance=additional_guidance,
            run_id=run_id,
            callback=callback,
            **kwargs,
        )
        return result.latents, result.attention_map_saver
@@ -539,6 +552,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        run_id: str = None,
        additional_guidance: List[Callable] = None,
        **kwargs,
    ):
        self._adjust_memory_efficient_attention(latents)
        if run_id is None:
@@ -578,6 +592,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                step_index=i,
                total_step_count=len(timesteps),
                additional_guidance=additional_guidance,
                **kwargs,
            )
            latents = step_output.prev_sample
@@ -618,6 +633,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        step_index: int,
        total_step_count: int,
        additional_guidance: List[Callable] = None,
        **kwargs,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]
@@ -629,6 +645,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)
        if (self.control_model is not None) and (kwargs.get("control_image") is not None):
            control_image = kwargs.get("control_image")  # a processed tensor derived from the control image(s)
            control_scale = kwargs.get("control_scale", 1.0)  # overall ControlNet conditioning scale
            # When multiple control models are used but only a single float control_scale is given,
            # broadcast that scale so there is one value per control model.
            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
                control_scale = [control_scale] * len(self.control_model.nets)
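                # e.g. two control nets with a scalar control_scale=0.5 become control_scale=[0.5, 0.5]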
            if conditioning_data.guidance_scale > 1.0:
                # Expand the latents input to the control model when doing classifier-free guidance.
                # (This currently appears to always be the case: a conditional elsewhere stops
                # execution when guidance_scale <= 1.0.)
                latent_control_input = torch.cat([latent_model_input] * 2)
            else:
                latent_control_input = latent_model_input
            # ControlNet inference
            down_block_res_samples, mid_block_res_sample = self.control_model(
                latent_control_input,
                timestep,
                encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
                                                 conditioning_data.text_embeddings]),
                controlnet_cond=control_image,
                conditioning_scale=control_scale,
                return_dict=False,
            )
        else:
            down_block_res_samples, mid_block_res_sample = None, None
        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input,
@@ -638,6 +681,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            conditioning_data.guidance_scale,
            step_index=step_index,
            total_step_count=total_step_count,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        )

        # compute the previous noisy sample x_t -> x_t-1
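For orientation, the two residual kwargs above follow the standard diffusers ControlNet contract: a ControlNetModel (or MultiControlNetModel) returns one residual per UNet down block plus a mid-block residual, and the UNet adds them to its skip connections. A minimal self-contained sketch of that contract (the checkpoint names are illustrative assumptions, not taken from this commit):

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

# Illustrative checkpoints; any SD-1.5-compatible ControlNet/UNet pair works the same way.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

latents = torch.randn(2, 4, 64, 64)          # CFG-doubled latent batch
timestep = torch.tensor(999)                 # one denoising timestep
text_emb = torch.randn(2, 77, 768)           # uncond + cond embeddings
control_image = torch.rand(2, 3, 512, 512)   # preprocessed control map in [0, 1]

with torch.no_grad():
    # One residual per UNet down block, plus one mid-block residual.
    down_res, mid_res = controlnet(
        latents,
        timestep,
        encoder_hidden_states=text_emb,
        controlnet_cond=control_image,
        conditioning_scale=1.0,
        return_dict=False,
    )
    # The UNet consumes them through the same kwargs used in this commit.
    noise_pred = unet(
        latents,
        timestep,
        text_emb,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample
print(noise_pred.shape)  # torch.Size([2, 4, 64, 64])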
@@ -659,6 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        t,
        text_embeddings,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:
@@ -678,7 +724,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).sample

    def img2img_from_embeddings(
@@ -940,3 +987,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            debug_image(
                img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
            )
    # Copied from diffusers pipeline_stable_diffusion_controlnet.py
    # Returns torch.Tensor of shape (batch_size, 3, height, width)
    def prepare_control_image(
        self,
        image,
        width=512,
        height=512,
        batch_size=1,
        num_images_per_prompt=1,
        device="cuda",
        dtype=torch.float16,
        do_classifier_free_guidance=True,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)
                image = images

                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)
        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)
        return image
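To make the preprocessing contract concrete, here is a hedged standalone re-implementation of the single-image path of prepare_control_image — a sketch for illustration, not the pipeline method itself:

import numpy as np
import PIL.Image
import torch

def prepare_control_image_sketch(image, width=512, height=512, do_classifier_free_guidance=True):
    # RGB, resized, scaled to [0, 1], NCHW — matching the method above for one PIL image.
    image = image.convert("RGB").resize((width, height), PIL.Image.LANCZOS)
    arr = np.array(image).astype(np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    if do_classifier_free_guidance:
        tensor = torch.cat([tensor] * 2)  # duplicate for the uncond/cond halves
    return tensor

control = prepare_control_image_sketch(PIL.Image.new("RGB", (768, 768)))
print(control.shape)  # torch.Size([2, 3, 512, 512])

The resulting tensor is exactly what the ControlNet branch in step() expects to receive via the control_image kwarg.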

View File

@@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
        unconditional_guidance_scale: float,
        step_index: Optional[int] = None,
        total_step_count: Optional[int] = None,
        **kwargs,
    ):
        """
        :param x: current latents
@@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
                x, sigma, unconditioning, conditioning, **kwargs,
            )
        elif wants_cross_attention_control:
            (
@@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
                **kwargs,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                x, sigma, unconditioning, conditioning, **kwargs,
            )
        else:
@@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                x, sigma, unconditioning, conditioning, **kwargs,
            )

        combined_next_x = self._combine(
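The **kwargs threading above is deliberately transparent: do_diffusion_step never inspects the ControlNet residuals, it only forwards them through whichever conditioning path runs to model_forward_callback (the pipeline's _unet_forward). A toy sketch of the pattern, with hypothetical names:

def unet_forward(latents, t, embeddings, **kwargs):
    # Stand-in for model_forward_callback; just reports which extras arrived.
    return sorted(kwargs)

def do_diffusion_step(x, sigma, unconditioning, conditioning, **kwargs):
    # Every conditioning strategy forwards **kwargs unchanged.
    return unet_forward(x, sigma, conditioning, **kwargs)

extras = do_diffusion_step(
    None, None, None, None,
    down_block_additional_residuals=[],
    mid_block_additional_residual=None,
)
print(extras)  # ['down_block_additional_residuals', 'mid_block_additional_residual']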
@@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:

    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([unconditioning, conditioning])
        both_results = self.model_forward_callback(
            x_twice, sigma_twice, both_conditionings, **kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        if conditioned_next_x.device.type == "mps":
@@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
        sigma,
        unconditioning: torch.Tensor,
        conditioning: torch.Tensor,
        **kwargs,
    ):
        # low-memory sequential path
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
        if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
        x_twice = torch.cat([x] * 2)
@@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
            else:
                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
            x_twice, sigma_twice, both_conditionings, **kwargs,
        ).chunk(2)
        return unconditioned_next_x, conditioned_next_x
@@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
        **kwargs,
    ):
        if self.is_running_diffusers:
            return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
                **kwargs,
            )
        else:
            return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
                **kwargs,
            )

    def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
        **kwargs,
    ):
        context: Context = self.cross_attention_control_context
@@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
            sigma,
            unconditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
            **kwargs,
        )

        # do requested cross attention types for conditioning (positive prompt)
@@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
            sigma,
            conditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
@@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
        **kwargs,
    ):
        # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
        # slower non-batched path (20% slower on mac MPS)
@@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
        context: Context = self.cross_attention_control_context
        try:
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)

            # process x using the original prompt, saving the attention maps
            # print("saving attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_save_attention_maps(ca_type)
            _ = self.model_forward_callback(x, sigma, conditioning, **kwargs)
            context.clear_requests(cleanup=False)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
                self.conditioning.cross_attention_control_args.edited_conditioning
            )
            conditioned_next_x = self.model_forward_callback(
                x, sigma, edited_conditioning, **kwargs,
            )
            context.clear_requests(cleanup=True)