Refactored ControlNet support to consolidate multiple parameters into a data struct. Also redid how multiple ControlNets are handled.

user1 2023-05-12 01:43:47 -07:00 committed by Kent Keirsey
parent d15bb88eb2
commit 11fc7e40a5
2 changed files with 28 additions and 31 deletions
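
The consolidation target is the ControlNetData struct defined in diffusers_pipeline.py. As orientation, a minimal hedged sketch assuming only the three fields visible at the call sites in this diff (model=, image_tensor=, weight=); the real definition may carry extra fields or defaults:

from dataclasses import dataclass

import torch
from diffusers import ControlNetModel

@dataclass
class ControlNetData:
    # Sketch only: fields inferred from ControlNetData(model=..., image_tensor=..., weight=...) below.
    model: ControlNetModel      # loaded ControlNet
    image_tensor: torch.Tensor  # pre-processed control image, shape (batch_size, 3, height, width)
    weight: float = 1.0         # per-ControlNet conditioning scale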

View File

@@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@@ -87,13 +90,13 @@ SAMPLER_NAME_VALUES = Literal[
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
    scheduler_config = model.scheduler.config
    if "_backup" in scheduler_config:
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
    scheduler = scheduler_class.from_config(scheduler_config)
    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
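
Aside on the "_backup" key above: it stashes the untouched scheduler config inside the patched one, so a later get_scheduler() call rebuilds from the pristine settings instead of compounding overrides. A small round-trip sketch (the config keys here are hypothetical examples):

original = {"num_train_timesteps": 1000, "beta_schedule": "scaled_linear"}  # pristine config
extra = {"algorithm_type": "dpmsolver++"}  # hypothetical scheduler_extra_config entry
patched = {**original, **extra, "_backup": original}
assert patched["_backup"] == original  # original is recoverable on the next scheduler swap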
@@ -255,9 +258,7 @@ class TextToLatentsInvocation(BaseInvocation):
        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)
        print("type of control input: ", type(self.control))
        # print("type of control input: ", type(self.control))
        if self.control is None:
            print("control input is None")
            control_list = None
@@ -266,14 +267,10 @@ class TextToLatentsInvocation(BaseInvocation):
            control_list = None
        elif isinstance(self.control, ControlField):
            print("control input is ControlField")
            # control = [self.control]
            control_list = [self.control]
        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
            print("control input is list[ControlField]")
            # print("using first controlnet in list")
            control_list = self.control
            # control = self.control
        else:
            print("input control is unrecognized:", type(self.control))
            control_list = None
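
The ladder above normalizes the node's control input to either None or list[ControlField]; an equivalent hypothetical helper, for clarity:

from typing import Optional, Union

def normalize_control(control: Union[None, ControlField, list]) -> Optional[list]:
    # ControlField is the node input type already imported in this module.
    if isinstance(control, ControlField):
        return [control]
    if isinstance(control, list) and len(control) > 0 and isinstance(control[0], ControlField):
        return control
    return None  # covers None, empty list, and unrecognized inputs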
@@ -281,25 +278,18 @@ class TextToLatentsInvocation(BaseInvocation):
        #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
        if (control_list is None):
            control_models = None
            control_weights = None
            control_images = None
            # from above handling, any control that is not None should now be of type list[ControlField]
        else:
            # FIXME: add checks to skip entry if model or image is None
            #        and if weight is None, populate with default 1.0?
            control_data = []
            control_models = []
            control_images = []
            control_weights = []
            for control_info in control_list:
                # handle control weights
                control_weights.append(control_info.control_weight)
                # handle control models
                # FIXME: change this to dropdown menu
                control_model = ControlNetModel.from_pretrained(control_info.control_model,
                                                                torch_dtype=model.unet.dtype).to(model.device)
                control_models.append(control_model)
                # handle control images
                # loading controlnet image (currently requires pre-processed image)
                # control_image = prep_control_image(control_info.image)
@@ -308,20 +298,26 @@ class TextToLatentsInvocation(BaseInvocation):
                # FIXME: still need to test with different widths, heights, devices, dtypes
                #        and add in batch_size, num_images_per_prompt?
                #        and do real check for classifier_free_guidance?
                # prepare_control_image should return torch.Tensor of shape (batch_size, 3, height, width)
                control_image = model.prepare_control_image(
                    image=input_image,
                    # do_classifier_free_guidance=do_classifier_free_guidance,
                    do_classifier_free_guidance=True,
                    width=control_width_resize,
                    height=control_height_resize,
                    # batch_size=batch_size * num_images_per_prompt,
                    # num_images_per_prompt=num_images_per_prompt,
                    device=control_model.device,
                    dtype=control_model.dtype,
                )
                control_images.append(control_image)
                multi_control = MultiControlNetModel(control_models)
                model.control_model = multi_control
                control_item = ControlNetData(model=control_model,
                                              image_tensor=control_image,
                                              weight=control_info.control_weight)
                control_data.append(control_item)
            # multi_control = MultiControlNetModel(control_models)
            # model.control_model = multi_control
            # model.control_model = control_models
        # TODO: Verify the noise is the right size
        result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -330,8 +326,7 @@ class TextToLatentsInvocation(BaseInvocation):
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback,
            control_image=control_images,
            control_weight=control_weights,
            control_data=control_data,  # list[ControlNetData]
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
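
For reference, a hypothetical two-ControlNet caller under the new interface. The Hugging Face repo ids are real checkpoints, but anything not shown in this diff (the initial latents argument, the prepared image tensors) is assumed to be in scope:

canny_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny",
                                            torch_dtype=model.unet.dtype).to(model.device)
depth_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth",
                                            torch_dtype=model.unet.dtype).to(model.device)
control_data = [
    ControlNetData(model=canny_net, image_tensor=canny_tensor, weight=1.0),  # canny_tensor assumed in scope
    ControlNetData(model=depth_net, image_tensor=depth_tensor, weight=0.5),  # weaker second net
]
result_latents, result_attention_map_saver = model.latents_from_embeddings(
    noise,                             # initial latents (name assumed)
    num_inference_steps=30,
    conditioning_data=conditioning_data,
    control_data=control_data,         # one ControlNetData per ControlNet
)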

View File

@@ -670,6 +670,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            else:
                latent_control_input = latent_model_input
            # control_data should be type List[ControlNetData]
            # this loop covers both ControlNet (one ControlNetData in list)
            # and MultiControlNet (multiple ControlNetData in list)
            for i, control_datum in enumerate(control_data):
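
Note on latent_control_input: the control image was prepared with do_classifier_free_guidance=True in the first file, i.e. presumably duplicated so one ControlNet pass covers both halves of the classifier-free-guidance batch. A sketch of that doubling:

import torch

control_image = torch.randn(1, 3, 512, 512)   # one prepared control image
cfg_control = torch.cat([control_image] * 2)  # doubled: [unconditioned, conditioned]
assert cfg_control.shape[0] == 2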
@@ -699,6 +700,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample
        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input,
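
Taken together, the loop in the hunks above reduces MultiControlNet to accumulate-and-sum over the ControlNetData list. A hedged sketch of the pattern (the forward signature is diffusers' standard ControlNetModel one; latent_control_input, t, and encoder_hidden_states are assumed in scope as in the surrounding step()):

down_block_res_samples, mid_block_res_sample = None, None
for control_datum in control_data:
    down_samples, mid_sample = control_datum.model(
        sample=latent_control_input,
        timestep=t,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=control_datum.image_tensor,
        conditioning_scale=control_datum.weight,  # per-ControlNet weight from ControlNetData
        return_dict=False,
    )
    if down_block_res_samples is None:
        # first ControlNet in the list initializes the residuals
        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
    else:
        # additional ControlNets: sum residuals elementwise
        down_block_res_samples = [
            samples_prev + samples_curr
            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
        ]
        mid_block_res_sample += mid_sample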