Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.

This commit is contained in:
user1 2023-05-12 01:43:47 -07:00 committed by Kent Keirsey
parent 48485fe92f
commit 63d248622c
2 changed files with 65 additions and 48 deletions

View File

@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np import numpy as np
from ..services.image_file_storage import ImageType from ..services.image_file_storage import ImageType
@ -260,9 +263,7 @@ class TextToLatentsInvocation(BaseInvocation):
model = self.get_model(context.services.model_manager) model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model) conditioning_data = self.get_conditioning_data(context, model)
# print("type of control input: ", type(self.control))
print("type of control input: ", type(self.control))
if self.control is None: if self.control is None:
print("control input is None") print("control input is None")
control_list = None control_list = None
@ -271,14 +272,10 @@ class TextToLatentsInvocation(BaseInvocation):
control_list = None control_list = None
elif isinstance(self.control, ControlField): elif isinstance(self.control, ControlField):
print("control input is ControlField") print("control input is ControlField")
# control = [self.control]
control_list = [self.control] control_list = [self.control]
# elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField): elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
print("control input is list[ControlField]") print("control input is list[ControlField]")
# print("using first controlnet in list")
control_list = self.control control_list = self.control
# control = self.control
else: else:
print("input control is unrecognized:", type(self.control)) print("input control is unrecognized:", type(self.control))
control_list = None control_list = None
@ -286,25 +283,18 @@ class TextToLatentsInvocation(BaseInvocation):
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)): #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
if (control_list is None): if (control_list is None):
control_models = None control_models = None
control_weights = None
control_images = None
# from above handling, any control that is not None should now be of type list[ControlField] # from above handling, any control that is not None should now be of type list[ControlField]
else: else:
# FIXME: add checks to skip entry if model or image is None # FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0? # and if weight is None, populate with default 1.0?
control_data = []
control_models = [] control_models = []
control_images = []
control_weights = []
for control_info in control_list: for control_info in control_list:
# handle control weights
control_weights.append(control_info.control_weight)
# handle control models # handle control models
# FIXME: change this to dropdown menu # FIXME: change this to dropdown menu
control_model = ControlNetModel.from_pretrained(control_info.control_model, control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
# handle control images # handle control images
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
# control_image = prep_control_image(control_info.image) # control_image = prep_control_image(control_info.image)
@ -313,6 +303,7 @@ class TextToLatentsInvocation(BaseInvocation):
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance? # and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image( control_image = model.prepare_control_image(
image=input_image, image=input_image,
# do_classifier_free_guidance=do_classifier_free_guidance, # do_classifier_free_guidance=do_classifier_free_guidance,
@ -324,9 +315,14 @@ class TextToLatentsInvocation(BaseInvocation):
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
) )
control_images.append(control_image) control_item = ControlNetData(model=control_model,
multi_control = MultiControlNetModel(control_models) image_tensor=control_image,
model.control_model = multi_control weight=control_info.control_weight)
control_data.append(control_item)
# multi_control = MultiControlNetModel(control_models)
# model.control_model = multi_control
# model.control_model = control_models
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings( result_latents, result_attention_map_saver = model.latents_from_embeddings(
@ -335,8 +331,7 @@ class TextToLatentsInvocation(BaseInvocation):
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
callback=step_callback, callback=step_callback,
control_image=control_images, control_data=control_data, # list[ControlNetData]
control_weight=control_weights,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699

View File

@ -2,10 +2,12 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import math
import secrets import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
import einops import einops
import PIL.Image import PIL.Image
@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?") raise AssertionError("why was that an empty generator?")
return result return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0)
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs, **kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False): if self.scheduler.config.get("cpu_only", False):
@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
control_data=control_data,
**kwargs, **kwargs,
) )
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs, **kwargs,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t) latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t.fill_(t)
step_output = self.step( step_output = self.step(
@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i, step_index=i,
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data,
**kwargs, **kwargs,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs, **kwargs,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) latent_model_input = self.scheduler.scale_model_input(latents, timestep)
if (self.control_model is not None) and (kwargs.get("control_image") is not None): # if (self.control_model is not None) and (control_image is not None):
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s) if control_data is not None:
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
# handling case where using multiple control models but only specifying single control_weight
# so reshape control_weight to match number of control models
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
control_weight = [control_weight] * len(self.control_model.nets)
if conditioning_data.guidance_scale > 1.0: if conditioning_data.guidance_scale > 1.0:
# expand the latents input to control model if doing classifier free guidance # expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if # (which I think for now is always true, there is conditional elsewhere that stops execution if
@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latent_control_input = torch.cat([latent_model_input] * 2) latent_control_input = torch.cat([latent_model_input] * 2)
else: else:
latent_control_input = latent_model_input latent_control_input = latent_model_input
# controlnet inference # control_data should be type List[ControlNetData]
down_block_res_samples, mid_block_res_sample = self.control_model( # this loop covers both ControlNet (1 ControlNetData in list)
latent_control_input, # and MultiControlNet (multiple ControlNetData in list)
timestep, for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]), conditioning_data.text_embeddings]),
controlnet_cond=control_image, controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_weight, conditioning_scale=control_datum.weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False, return_dict=False,
) )
if i == 0:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
else: else:
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None