From 63d248622cbc0c391879f93279b23a347f5787c5 Mon Sep 17 00:00:00 2001
From: user1
Date: Fri, 12 May 2023 01:43:47 -0700
Subject: [PATCH] Refactored ControlNet support to consolidate multiple
 parameters into a data struct. Also redid how multiple ControlNets are
 handled.

---
 invokeai/app/invocations/latent.py              | 53 ++++++++--------
 .../stable_diffusion/diffusers_pipeline.py      | 60 +++++++++++++------
 2 files changed, 65 insertions(+), 48 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 1d1eb1963a..b9ce5c10a8 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
+
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
+
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_file_storage import ImageType
@@ -260,9 +263,7 @@ class TextToLatentsInvocation(BaseInvocation):
 
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)
-
-        print("type of control input: ", type(self.control))
-
+        # print("type of control input: ", type(self.control))
         if self.control is None:
             print("control input is None")
             control_list = None
@@ -271,14 +272,10 @@ class TextToLatentsInvocation(BaseInvocation):
             control_list = None
         elif isinstance(self.control, ControlField):
             print("control input is ControlField")
-            # control = [self.control]
             control_list = [self.control]
-        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
         elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
             print("control input is list[ControlField]")
-            # print("using first controlnet in list")
             control_list = self.control
-            # control = self.control
         else:
             print("input control is unrecognized:", type(self.control))
             control_list = None
@@ -286,25 +283,18 @@ class TextToLatentsInvocation(BaseInvocation):
 
         #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
         if (control_list is None):
             control_models = None
-            control_weights = None
-            control_images = None
         # from above handling, any control that is not None should now be of type list[ControlField]
         else:
             # FIXME: add checks to skip entry if model or image is None
             #        and if weight is None, populate with default 1.0?
+            control_data = []
             control_models = []
-            control_images = []
-            control_weights = []
             for control_info in control_list:
-                # handle control weights
-                control_weights.append(control_info.control_weight)
-
                 # handle control models
                 # FIXME: change this to dropdown menu
                 control_model = ControlNetModel.from_pretrained(control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
                 control_models.append(control_model)
-
                 # handle control images
                 # loading controlnet image (currently requires pre-processed image)
                 # control_image = prep_control_image(control_info.image)
@@ -313,20 +303,26 @@ class TextToLatentsInvocation(BaseInvocation):
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 #        and add in batch_size, num_images_per_prompt?
                 #        and do real check for classifier_free_guidance?
+                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
                 control_image = model.prepare_control_image(
-                        image=input_image,
-                        # do_classifier_free_guidance=do_classifier_free_guidance,
-                        do_classifier_free_guidance=True,
-                        width=control_width_resize,
-                        height=control_height_resize,
-                        # batch_size=batch_size * num_images_per_prompt,
-                        # num_images_per_prompt=num_images_per_prompt,
-                        device=control_model.device,
-                        dtype=control_model.dtype,
+                    image=input_image,
+                    # do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=True,
+                    width=control_width_resize,
+                    height=control_height_resize,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=control_model.device,
+                    dtype=control_model.dtype,
                 )
-                control_images.append(control_image)
-            multi_control = MultiControlNetModel(control_models)
-            model.control_model = multi_control
+                control_item = ControlNetData(model=control_model,
+                                              image_tensor=control_image,
+                                              weight=control_info.control_weight)
+                control_data.append(control_item)
+            # multi_control = MultiControlNetModel(control_models)
+            # model.control_model = multi_control
+            # model.control_model = control_models
+

         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -335,8 +331,7 @@ class TextToLatentsInvocation(BaseInvocation):
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
             callback=step_callback,
-            control_image=control_images,
-            control_weight=control_weights,
+            control_data=control_data,  # list[ControlNetData]
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 25154d61cd..a6c365967c 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -2,10 +2,12 @@ from __future__ import annotations

 import dataclasses
 import inspect
+import math
 import secrets
 from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
+from pydantic import BaseModel, Field

 import einops
 import PIL.Image
@@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
             raise AssertionError("why was that an empty generator?")
         return result

+@dataclass
+class ControlNetData:
+    model: ControlNetModel = Field(default=None)
+    image_tensor: torch.Tensor = Field(default=None)
+    weight: float = Field(default=1.0)
+

 @dataclass(frozen=True)
 class ConditioningData:
@@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
+        control_data: List[ControlNetData] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
@@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance=additional_guidance,
             run_id=run_id,
             callback=callback,
+            control_data=control_data,
             **kwargs,
         )
         return result.latents, result.attention_map_saver
@@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         run_id: str = None,
         additional_guidance: List[Callable] = None,
+        control_data: List[ControlNetData] = None,
         **kwargs,
     ):
         self._adjust_memory_efficient_attention(latents)
@@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             latents = self.scheduler.add_noise(latents, noise, batched_t)

             attention_map_saver: Optional[AttentionMapSaver] = None
-
+            # print("timesteps:", timesteps)
             for i, t in enumerate(self.progress_bar(timesteps)):
                 batched_t.fill_(t)
                 step_output = self.step(
@@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     step_index=i,
                     total_step_count=len(timesteps),
                     additional_guidance=additional_guidance,
+                    control_data=control_data,
                     **kwargs,
                 )
                 latents = step_output.prev_sample
@@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         step_index: int,
         total_step_count: int,
         additional_guidance: List[Callable] = None,
+        control_data: List[ControlNetData] = None,
         **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
-
         if additional_guidance is None:
             additional_guidance = []

@@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         #     i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)

-        if (self.control_model is not None) and (kwargs.get("control_image") is not None):
-            control_image = kwargs.get("control_image")  # should be a processed tensor derived from the control image(s)
-            control_weight = kwargs.get("control_weight", 1.0)  # control_weight default is 1.0
-            # handling case where using multiple control models but only specifying single control_weight
-            #    so reshape control_weight to match number of control models
-            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
-                control_weight = [control_weight] * len(self.control_model.nets)
+        # if (self.control_model is not None) and (control_image is not None):
+        if control_data is not None:
             if conditioning_data.guidance_scale > 1.0:
                 # expand the latents input to control model if doing classifier free guidance
                 # (which I think for now is always true, there is conditional elsewhere that stops execution if
@@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 latent_control_input = torch.cat([latent_model_input] * 2)
             else:
                 latent_control_input = latent_model_input
-            # controlnet inference
-            down_block_res_samples, mid_block_res_sample = self.control_model(
-                latent_control_input,
-                timestep,
-                encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
-                                                 conditioning_data.text_embeddings]),
-                controlnet_cond=control_image,
-                conditioning_scale=control_weight,
-                return_dict=False,
-            )
+            # control_data should be type List[ControlNetData]
+            # this loop covers both ControlNet (1 ControlNetData in list)
+            #      and MultiControlNet (multiple ControlNetData in list)
+            for i, control_datum in enumerate(control_data):
+                # print("controlnet", i, "==>", type(control_datum))
+                down_samples, mid_sample = control_datum.model(
+                    sample=latent_control_input,
+                    timestep=timestep,
+                    encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
+                                                     conditioning_data.text_embeddings]),
+                    controlnet_cond=control_datum.image_tensor,
+                    conditioning_scale=control_datum.weight,
+                    # cross_attention_kwargs,
+                    guess_mode=False,
+                    return_dict=False,
+                )
+                if i == 0:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    # add controlnet outputs together if have multiple controlnets
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
         else:
             down_block_res_samples, mid_block_res_sample = None, None
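
Note (illustration, not part of the diff): the sketch below shows the accumulation pattern that step() now uses for multiple ControlNets, with plain torch tensors standing in for real diffusers ControlNetModel outputs. The ControlNetData here mirrors the dataclass added in diffusers_pipeline.py, while fake_controlnet() and the tensor shapes are hypothetical stand-ins chosen only to keep the example self-contained and runnable.

# Standalone sketch of the multi-ControlNet residual accumulation used in step().
# Only torch is required; ControlNetData mirrors the dataclass added by this patch,
# and fake_controlnet() is a hypothetical stand-in for a diffusers ControlNetModel call.
from dataclasses import dataclass
from typing import Callable, List, Tuple

import torch


@dataclass
class ControlNetData:
    model: Callable  # stand-in for ControlNetModel
    image_tensor: torch.Tensor
    weight: float = 1.0


def fake_controlnet(latents: torch.Tensor, cond: torch.Tensor, scale: float) -> Tuple[List[torch.Tensor], torch.Tensor]:
    # A real ControlNetModel returns one residual per down block plus a mid-block
    # residual; here fixed-shape tensors scaled by the weight are fabricated instead.
    down = [torch.randn(2, 320, 64, 64) * scale for _ in range(3)]
    mid = torch.randn(2, 1280, 8, 8) * scale
    return down, mid


def accumulate_control(control_data: List[ControlNetData], latent_control_input: torch.Tensor):
    # Same shape of loop as in step(): the first ControlNet initializes the residuals,
    # every later one is summed element-wise onto them.
    down_block_res_samples, mid_block_res_sample = None, None
    for i, control_datum in enumerate(control_data):
        down_samples, mid_sample = control_datum.model(
            latent_control_input, control_datum.image_tensor, control_datum.weight
        )
        if i == 0:
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            down_block_res_samples = [
                prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample += mid_sample
    return down_block_res_samples, mid_block_res_sample


if __name__ == "__main__":
    latents = torch.randn(2, 4, 64, 64)
    controls = [
        ControlNetData(model=fake_controlnet, image_tensor=torch.randn(2, 3, 512, 512), weight=1.0),
        ControlNetData(model=fake_controlnet, image_tensor=torch.randn(2, 3, 512, 512), weight=0.5),
    ]
    down, mid = accumulate_control(controls, latents)
    print(len(down), down[0].shape, mid.shape)

Summing the per-block residuals is, to my understanding, what diffusers' own MultiControlNetModel does internally; keeping the loop in step() lets each ControlNetData carry its own weight as conditioning_scale instead of reshaping a shared control_weight list.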