From 11fc7e40a588ddfd8148bb96df27aa00569479c3 Mon Sep 17 00:00:00 2001
From: user1
Date: Fri, 12 May 2023 01:43:47 -0700
Subject: [PATCH] Refactored ControlNet support to consolidate multiple
 parameters into a data struct. Also reworked how multiple ControlNets are
 handled.

---
 invokeai/app/invocations/latent.py                 | 57 +++++++++----------
 .../stable_diffusion/diffusers_pipeline.py         |  2 +
 2 files changed, 28 insertions(+), 31 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index f6516ee99f..94775e9a44 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
+
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
+
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
@@ -87,13 +90,13 @@ SAMPLER_NAME_VALUES = Literal[
 
 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
-    
+
     scheduler_config = model.scheduler.config
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
     scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
     scheduler = scheduler_class.from_config(scheduler_config)
-    
+
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -255,9 +258,7 @@ class TextToLatentsInvocation(BaseInvocation):
 
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)
-
-        print("type of control input: ", type(self.control))
-
+        # print("type of control input: ", type(self.control))
         if self.control is None:
             print("control input is None")
             control_list = None
@@ -266,14 +267,10 @@
             control_list = None
         elif isinstance(self.control, ControlField):
             print("control input is ControlField")
-            # control = [self.control]
             control_list = [self.control]
-        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
         elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
             print("control input is list[ControlField]")
-            # print("using first controlnet in list")
             control_list = self.control
-            # control = self.control
         else:
             print("input control is unrecognized:", type(self.control))
             control_list = None
@@ -281,25 +278,18 @@
         #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
         if (control_list is None):
             control_models = None
-            control_weights = None
-            control_images = None
         # from above handling, any control that is not None should now be of type list[ControlField]
         else:
             # FIXME: add checks to skip entry if model or image is None
             #        and if weight is None, populate with default 1.0?
+            control_data = []
             control_models = []
-            control_images = []
-            control_weights = []
             for control_info in control_list:
-                # handle control weights
-                control_weights.append(control_info.control_weight)
-
                 # handle control models
                 # FIXME: change this to dropdown menu
                 control_model = ControlNetModel.from_pretrained(control_info.control_model,
                                                                 torch_dtype=model.unet.dtype).to(model.device)
                 control_models.append(control_model)
-
                 # handle control images
                 # loading controlnet image (currently requires pre-processed image)
                 # control_image = prep_control_image(control_info.image)
@@ -308,20 +298,26 @@
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 #        and add in batch_size, num_images_per_prompt?
                 #        and do real check for classifier_free_guidance?
+                # prepare_control_image should return torch.Tensor of shape (batch_size, 3, height, width)
                 control_image = model.prepare_control_image(
-                        image=input_image,
-                        # do_classifier_free_guidance=do_classifier_free_guidance,
-                        do_classifier_free_guidance=True,
-                        width=control_width_resize,
-                        height=control_height_resize,
-                        # batch_size=batch_size * num_images_per_prompt,
-                        # num_images_per_prompt=num_images_per_prompt,
-                        device=control_model.device,
-                        dtype=control_model.dtype,
+                    image=input_image,
+                    # do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=True,
+                    width=control_width_resize,
+                    height=control_height_resize,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=control_model.device,
+                    dtype=control_model.dtype,
                 )
-                control_images.append(control_image)
-            multi_control = MultiControlNetModel(control_models)
-            model.control_model = multi_control
+                control_item = ControlNetData(model=control_model,
+                                              image_tensor=control_image,
+                                              weight=control_info.control_weight)
+                control_data.append(control_item)
+            # multi_control = MultiControlNetModel(control_models)
+            # model.control_model = multi_control
+            # model.control_model = control_models
+
 
         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -330,8 +326,7 @@
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
             callback=step_callback,
-            control_image=control_images,
-            control_weight=control_weights,
+            control_data=control_data,  # list[ControlNetData]
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 9656fe7eee..5fe8289741 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -670,6 +670,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             else:
                 latent_control_input = latent_model_input
             # control_data should be type List[ControlNetData]
+
             # this loop covers both ControlNet (one ControlNetData in list)
             #      and MultiControlNet (multiple ControlNetData in list)
             for i, control_datum in enumerate(control_data):
@@ -699,6 +700,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                     ]
                     mid_block_res_sample += mid_sample
+
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             latent_model_input,
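
Note on the new struct: the hunks above construct ControlNetData(model=...,
image_tensor=..., weight=...) and import it from diffusers_pipeline. A minimal
sketch consistent with those call sites follows; the field types and the 1.0
default are assumptions inferred from the FIXME note, not InvokeAI's actual
definition.

    # Hedged sketch of ControlNetData, inferred from the call sites in the diff.
    from dataclasses import dataclass

    import torch
    from diffusers import ControlNetModel


    @dataclass
    class ControlNetData:
        model: ControlNetModel      # loaded ControlNet applied during denoising
        image_tensor: torch.Tensor  # pre-processed control image, (batch_size, 3, height, width)
        weight: float = 1.0         # conditioning scale; default assumed per the FIXME note

Grouping the three values in one record is what lets a single control_data list
replace the parallel control_models/control_images/control_weights lists that
the patch deletes.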
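
The loop relies on model.prepare_control_image(...) to turn each control image
into the tensor the ControlNet consumes. A hedged sketch of what such a helper
typically does, inferred only from the kwargs and the shape comment in the hunk
(prepare_control_image_sketch and its body are illustrative, not InvokeAI's
implementation):

    import numpy as np
    import torch
    from PIL import Image


    def prepare_control_image_sketch(image: Image.Image, width: int, height: int,
                                     do_classifier_free_guidance: bool,
                                     device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # Resize to the requested size and normalize to [0, 1].
        image = image.convert("RGB").resize((width, height), Image.LANCZOS)
        tensor = torch.from_numpy(np.array(image)).float() / 255.0  # (height, width, 3)
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)               # (1, 3, height, width)
        tensor = tensor.to(device=device, dtype=dtype)
        if do_classifier_free_guidance:
            # Duplicate the batch so the unconditioned and conditioned UNet
            # passes each see the control image.
            tensor = torch.cat([tensor] * 2)
        return tensor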
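
With the consolidation, callers pass one list of ControlNetData records instead
of the old parallel control_image/control_weight arguments. A hypothetical call
mirroring the new keyword argument (canny_model, depth_image, noise, steps and
friends are placeholders; the first positional argument is not shown in the
hunk and is assumed):

    control_data = [
        ControlNetData(model=canny_model, image_tensor=canny_image, weight=1.0),
        ControlNetData(model=depth_model, image_tensor=depth_image, weight=0.5),
    ]
    result_latents, result_attention_map_saver = model.latents_from_embeddings(
        noise,                      # initial latents; positional arg assumed
        num_inference_steps=steps,
        conditioning_data=conditioning_data,
        callback=step_callback,
        control_data=control_data,  # list[ControlNetData]
    )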
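
On the diffusers_pipeline.py side, iterating the list replaces the old
MultiControlNetModel wrapper: each ControlNetData runs its own ControlNet
forward pass and the residuals are summed, which is what MultiControlNetModel
does internally. A sketch of that accumulation; variable names follow the diff,
the forward call follows diffusers' ControlNetModel with return_dict=False, and
the encoder_hidden_states/timestep bindings are assumed from context:

    down_block_res_samples, mid_block_res_sample = None, None
    for control_datum in control_data:
        down_samples, mid_sample = control_datum.model(
            latent_control_input,
            timestep,
            encoder_hidden_states=conditioning,        # name assumed
            controlnet_cond=control_datum.image_tensor,
            conditioning_scale=control_datum.weight,   # per-ControlNet weight from the struct
            return_dict=False,
        )
        if down_block_res_samples is None:
            # the first ControlNet in the list seeds the accumulators
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            # merge residuals from additional ControlNets, matching the zip() lines above
            down_block_res_samples = [
                samples_prev + samples_curr
                for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample += mid_sample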