From 5ff98a417963cb84941fca28de237def1ea96255 Mon Sep 17 00:00:00 2001 From: user1 Date: Sat, 29 Apr 2023 00:43:21 -0700 Subject: [PATCH] Core implementation of ControlNet and MultiControlNet. --- .../stable_diffusion/diffusers_pipeline.py | 94 ++++++++++++++++++- .../diffusion/shared_invokeai_diffusion.py | 34 ++++--- 2 files changed, 115 insertions(+), 13 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 4ca2a5cb30..758779b735 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -9,16 +9,20 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union import einops import PIL.Image +import numpy as np from accelerate.utils import set_seed import psutil import torch import torchvision.transforms as T from compel import EmbeddingsProvider from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( StableDiffusionImg2ImgPipeline, ) @@ -27,6 +31,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( ) from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from diffusers.utils import PIL_INTERPOLATION from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize @@ -302,6 +307,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, precision: str = "float32", + control_model: ControlNetModel = None, ): super().__init__( vae, @@ -322,6 +328,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + # FIXME: can't currently register control module + # control_model=control_model, ) self.invokeai_diffuser = InvokeAIDiffuserComponent( self.unet, self._unet_forward, is_running_diffusers=True @@ -341,6 +349,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group.install(*self._submodels) + self.control_model = control_model def _adjust_memory_efficient_attention(self, latents: torch.Tensor): """ @@ -463,6 +472,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, callback: Callable[[PipelineIntermediateState], None] = None, run_id=None, + **kwargs, ) -> InvokeAIStableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. @@ -483,6 +493,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise=noise, run_id=run_id, callback=callback, + **kwargs, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -507,6 +518,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, run_id=None, callback: Callable[[PipelineIntermediateState], None] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device('cpu') @@ -527,6 +539,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance=additional_guidance, run_id=run_id, callback=callback, + **kwargs, ) return result.latents, result.attention_map_saver @@ -539,6 +552,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, run_id: str = None, additional_guidance: List[Callable] = None, + **kwargs, ): self._adjust_memory_efficient_attention(latents) if run_id is None: @@ -578,6 +592,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=i, total_step_count=len(timesteps), additional_guidance=additional_guidance, + **kwargs, ) latents = step_output.prev_sample @@ -618,6 +633,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index: int, total_step_count: int, additional_guidance: List[Callable] = None, + **kwargs, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] @@ -629,6 +645,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # i.e. before or after passing it to InvokeAIDiffuserComponent latent_model_input = self.scheduler.scale_model_input(latents, timestep) + if (self.control_model is not None) and (kwargs.get("control_image") is not None): + control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s) + control_scale = kwargs.get("control_scale", 1.0) # control_scale default is 1.0 + # handling case where using multiple control models but only specifying single control_scale + # so reshape control_scale to match number of control models + if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float): + control_scale = [control_scale] * len(self.control_model.nets) + if conditioning_data.guidance_scale > 1.0: + # expand the latents input to control model if doing classifier free guidance + # (which I think for now is always true, there is conditional elsewhere that stops execution if + # classifier_free_guidance is <= 1.0 ?) + latent_control_input = torch.cat([latent_model_input] * 2) + else: + latent_control_input = latent_model_input + # controlnet inference + down_block_res_samples, mid_block_res_sample = self.control_model( + latent_control_input, + timestep, + encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, + conditioning_data.text_embeddings]), + controlnet_cond=control_image, + conditioning_scale=control_scale, + return_dict=False, + ) + else: + down_block_res_samples, mid_block_res_sample = None, None + # predict the noise residual noise_pred = self.invokeai_diffuser.do_diffusion_step( latent_model_input, @@ -638,6 +681,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data.guidance_scale, step_index=step_index, total_step_count=total_step_count, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, ) # compute the previous noisy sample x_t -> x_t-1 @@ -659,6 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): t, text_embeddings, cross_attention_kwargs: Optional[dict[str, Any]] = None, + **kwargs, ): """predict the noise residual""" if is_inpainting_model(self.unet) and latents.size(1) == 4: @@ -678,7 +724,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # First three args should be positional, not keywords, so torch hooks can see them. return self.unet( - latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs + latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs, + **kwargs, ).sample def img2img_from_embeddings( @@ -940,3 +987,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): debug_image( img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True ) + + # Copied from diffusers pipeline_stable_diffusion_controlnet.py + # Returns torch.Tensor of shape (batch_size, 3, height, width) + def prepare_control_image( + self, + image, + width=512, + height=512, + batch_size=1, + num_images_per_prompt=1, + device="cuda", + dtype=torch.float16, + do_classifier_free_guidance=True, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + image = images + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + return image diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 4131837b41..d05565c506 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent: unconditional_guidance_scale: float, step_index: Optional[int] = None, total_step_count: Optional[int] = None, + **kwargs, ): """ :param x: current latents @@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent: if wants_hybrid_conditioning: unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) elif wants_cross_attention_control: ( @@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) elif self.sequential_guidance: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) else: @@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) combined_next_x = self._combine( @@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) if conditioned_next_x.device.type == "mps": @@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent: sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor, + **kwargs, ): # low-memory sequential path - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) + conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) if conditioned_next_x.device.type == "mps": # prevent a result filled with zeros. seems to be a torch bug. conditioned_next_x = conditioned_next_x.clone() return unconditioned_next_x, conditioned_next_x - def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): assert isinstance(conditioning, dict) assert isinstance(unconditioning, dict) x_twice = torch.cat([x] * 2) @@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent: else: both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ).chunk(2) return unconditioned_next_x, conditioned_next_x @@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): if self.is_running_diffusers: return self._apply_cross_attention_controlled_conditioning__diffusers( @@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) else: return self._apply_cross_attention_controlled_conditioning__compvis( @@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) def _apply_cross_attention_controlled_conditioning__diffusers( @@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): context: Context = self.cross_attention_control_context @@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent: sigma, unconditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) # do requested cross attention types for conditioning (positive prompt) @@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent: sigma, conditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) return unconditioned_next_x, conditioned_next_x @@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) @@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent: context: Context = self.cross_attention_control_context try: - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) # process x using the original prompt, saving the attention maps # print("saving attention maps for", cross_attention_control_types_to_do) for ca_type in cross_attention_control_types_to_do: context.request_save_attention_maps(ca_type) - _ = self.model_forward_callback(x, sigma, conditioning) + _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,) context.clear_requests(cleanup=False) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied @@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent: self.conditioning.cross_attention_control_args.edited_conditioning ) conditioned_next_x = self.model_forward_callback( - x, sigma, edited_conditioning + x, sigma, edited_conditioning, **kwargs, ) context.clear_requests(cleanup=True)