Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Refactored ControlNet support to consolidate multiple parameters into a data struct. Also redid how multiple ControlNets are handled.
parent d15bb88eb2
commit 11fc7e40a5
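The refactor replaces the parallel control_image / control_weight / control_model parameters that previously threaded through the pipeline with a single ControlNetData struct per ControlNet, and the pipeline now receives a list of these structs. A minimal sketch of the idea, assuming only the three fields visible in this diff (model, image_tensor, weight); the default weight is an assumption:

    # Sketch only: fields inferred from the ControlNetData(...) call later in this diff.
    from dataclasses import dataclass

    import torch
    from diffusers import ControlNetModel

    @dataclass
    class ControlNetData:
        model: ControlNetModel      # the loaded ControlNet
        image_tensor: torch.Tensor  # pre-processed control image, shape (batch, 3, H, W)
        weight: float = 1.0         # assumed default; scales this net's residuals

    # One entry per active ControlNet; the single-net and multi-net cases then share one code path.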
@@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec

from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@@ -87,13 +90,13 @@ SAMPLER_NAME_VALUES = Literal[

def get_scheduler(scheduler_name: str, model: StableDiffusionGeneratorPipeline) -> Scheduler:
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

    scheduler_config = model.scheduler.config
    if "_backup" in scheduler_config:
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
    scheduler = scheduler_class.from_config(scheduler_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
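The "_backup" key in get_scheduler lets the per-scheduler extra config be overlaid on the model's current scheduler config without compounding: if a backup of the pristine config exists it is restored first, then the extras are applied and the pristine copy is stashed again. A minimal sketch of that round-trip using plain dicts (helper and key names beyond "_backup" are illustrative):

    # Sketch of the config-backup pattern in get_scheduler above.
    def overlay_config(current: dict, extra: dict) -> dict:
        base = current.get("_backup", current)     # recover the pristine config if present
        return {**base, **extra, "_backup": base}  # apply extras, stash the original

    cfg = {"num_train_timesteps": 1000}
    cfg = overlay_config(cfg, {"use_karras_sigmas": True})
    cfg = overlay_config(cfg, {"use_karras_sigmas": False})
    assert cfg["_backup"] == {"num_train_timesteps": 1000}  # original survives repeated overlays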
@@ -255,9 +258,7 @@ class TextToLatentsInvocation(BaseInvocation):

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        print("type of control input: ", type(self.control))

        # print("type of control input: ", type(self.control))
        if self.control is None:
            print("control input is None")
            control_list = None
@@ -266,14 +267,10 @@ class TextToLatentsInvocation(BaseInvocation):
            control_list = None
        elif isinstance(self.control, ControlField):
            print("control input is ControlField")
            # control = [self.control]
            control_list = [self.control]
        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
            print("control input is list[ControlField]")
            # print("using first controlnet in list")
            control_list = self.control
            # control = self.control
        else:
            print("input control is unrecognized:", type(self.control))
            control_list = None
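The if/elif chain above normalizes whatever arrives on self.control — None, a single ControlField, or a list of ControlFields — into either None or a non-empty list, so the code that follows only has two cases to handle. The same logic, condensed into a hypothetical standalone helper with the debug prints dropped:

    # Sketch: collapse the accepted control inputs into a list or None.
    def normalize_control(control):
        if control is None:
            return None
        if isinstance(control, ControlField):
            return [control]
        if isinstance(control, list) and len(control) > 0 and isinstance(control[0], ControlField):
            return control
        return None  # unrecognized input is treated like "no control"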
@@ -281,25 +278,18 @@ class TextToLatentsInvocation(BaseInvocation):
        #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
        if (control_list is None):
            control_models = None
            control_weights = None
            control_images = None
        # from above handling, any control that is not None should now be of type list[ControlField]
        else:
            # FIXME: add checks to skip entry if model or image is None
            #        and if weight is None, populate with default 1.0?
            control_data = []
            control_models = []
            control_images = []
            control_weights = []
            for control_info in control_list:
                # handle control weights
                control_weights.append(control_info.control_weight)

                # handle control models
                # FIXME: change this to dropdown menu
                control_model = ControlNetModel.from_pretrained(control_info.control_model,
                                                                torch_dtype=model.unet.dtype).to(model.device)
                control_models.append(control_model)

                # handle control images
                # loading controlnet image (currently requires pre-processed image)
                # control_image = prep_control_image(control_info.image)
@@ -308,20 +298,26 @@ class TextToLatentsInvocation(BaseInvocation):
                # FIXME: still need to test with different widths, heights, devices, dtypes
                #        and add in batch_size, num_images_per_prompt?
                #        and do real check for classifier_free_guidance?
                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
                control_image = model.prepare_control_image(
                    image=input_image,
                    # do_classifier_free_guidance=do_classifier_free_guidance,
                    do_classifier_free_guidance=True,
                    width=control_width_resize,
                    height=control_height_resize,
                    # batch_size=batch_size * num_images_per_prompt,
                    # num_images_per_prompt=num_images_per_prompt,
                    device=control_model.device,
                    dtype=control_model.dtype,
                )
                control_images.append(control_image)
                multi_control = MultiControlNetModel(control_models)
                model.control_model = multi_control
                control_item = ControlNetData(model=control_model,
                                              image_tensor=control_image,
                                              weight=control_info.control_weight)
                control_data.append(control_item)
            # multi_control = MultiControlNetModel(control_models)
            # model.control_model = multi_control
            # model.control_model = control_models

        # TODO: Verify the noise is the right size
        result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -330,8 +326,7 @@ class TextToLatentsInvocation(BaseInvocation):
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback,
            control_image=control_images,
            control_weight=control_weights,
            control_data=control_data,  # list[ControlNetData]
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
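With ControlNetData in place, the generator call consolidates: the parallel control_image and control_weight kwargs are gone and a single control_data list is passed instead. A hedged sketch of the resulting call site; the latents argument is an assumption (not shown in this hunk) and unrelated kwargs are elided:

    # Sketch of the consolidated call; only the kwargs visible in this diff are certain.
    control_data = [
        ControlNetData(model=m, image_tensor=img, weight=w)
        for m, img, w in zip(control_models, control_images, control_weights)
    ]
    result_latents, result_attention_map_saver = model.latents_from_embeddings(
        latents=noise,  # assumption: not shown in this hunk
        num_inference_steps=self.steps,
        conditioning_data=conditioning_data,
        callback=step_callback,
        control_data=control_data,  # list[ControlNetData], one entry per active ControlNet
    )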
@@ -670,6 +670,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        else:
            latent_control_input = latent_model_input
        # control_data should be type List[ControlNetData]

        # this loop covers both ControlNet (one ControlNetData in list)
        # and MultiControlNet (multiple ControlNetData in list)
        for i, control_datum in enumerate(control_data):
@@ -699,6 +700,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input,
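Each iteration of the loop above runs one ControlNet and folds its residuals into shared accumulators: the first net initializes down_block_res_samples and mid_block_res_sample, and every subsequent net is summed in element-wise, which is why a single loop serves both the one-net and MultiControlNet cases. A minimal sketch of that accumulation with the ControlNet forward calls stubbed out; names beyond those in the hunks above are illustrative:

    # Sketch of residual accumulation across multiple ControlNets.
    def accumulate_residuals(per_net_outputs):
        # per_net_outputs: one (down_samples, mid_sample) pair per ControlNet,
        # each already scaled by that net's weight.
        down_block_res_samples, mid_block_res_sample = None, None
        for down_samples, mid_sample in per_net_outputs:
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample
        return down_block_res_samples, mid_block_res_sample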