First pass at ControlNet "guess mode" implementation.

This commit is contained in:
user1 2023-06-11 02:00:39 -07:00
parent c647056287
commit fd715026a7
3 changed files with 93 additions and 25 deletions

View File

@ -1,7 +1,7 @@
# InvokeAI nodes for ControlNet image preprocessors # InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float from builtins import float, bool
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List
@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
] ]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
# CONTROLNET_MODE_VALUES = Literal[tuple(["BALANCED", "PROMPT", "CONTROL"])]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
@ -104,6 +105,7 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def abs_le_one(cls, v):
"""validate that all abs(values) are <=1""" """validate that all abs(values) are <=1"""
@ -149,6 +151,7 @@ class ControlNetInvocation(BaseInvocation):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -174,6 +177,7 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
guess_mode=self.guess_mode,
), ),
) )

View File

@ -337,12 +337,14 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
guess_mode=control_info.guess_mode,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(model=control_model,
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent) end_step_percent=control_info.end_step_percent,
guess_mode=control_info.guess_mode,)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data

View File

@ -221,6 +221,8 @@ class ControlNetData:
weight: Union[float, List[float]] = Field(default=1.0) weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
# FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -656,21 +658,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward? # TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# # guess mode handling from diffusers
# if guess_mode and do_classifier_free_guidance:
# # Infer ControlNet only for the conditional batch.
# control_model_input = latents
# control_model_input = self.scheduler.scale_model_input(control_model_input, t)
# controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
# else:
# control_model_input = unet_latent_input
# controlnet_prompt_embeds = prompt_embeds
# default is no controlnet, so set controlnet processing output to None # default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting # FIXME: make sure guidance_scale <= 1.0 is handled correctly if doing per-step guidance setting
# UPDATE: I think this is fixed now with pydantic validator for cfg_scale?
# So we should _never_ have guidance_scale <= 1.0
# if conditioning_data.guidance_scale > 1.0: # if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None: # if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance # if guess_mode is False:
# (which I think for now is always true, there is conditional elsewhere that stops execution if # # expand the latents input to control model if doing classifier free guidance
# classifier_free_guidance is <= 1.0 ?) # # (which I think for now is always true, there is conditional elsewhere that stops execution if
latent_control_input = torch.cat([latent_model_input] * 2) # # classifier_free_guidance is <= 1.0 ?)
else: # control_latent_input = torch.cat([unet_latent_input] * 2)
latent_control_input = latent_model_input # else:
# control_latent_input = unet_latent_input
# control_data should be type List[ControlNetData] # control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list) # this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list)
@ -680,24 +695,62 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index) guess_mode = control_datum.guess_mode
if guess_mode:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
control_latent_input = torch.cat([unet_latent_input] * 2)
print("running controlnet", i, "for step", step_index)
print("guess mode: ", guess_mode)
print("guess mode type: ", type(guess_mode))
if guess_mode: # only using prompt conditioning in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
print("encoder_hidden_states.shape", encoder_hidden_states.shape)
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index] controlnet_weight = control_datum.weight[step_index]
else: else:
# if controlnet has a single weight, use it for all steps # if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight controlnet_weight = control_datum.weight
# guess mode handling from diffusers controlnet pipeline:
# if guess_mode and do_classifier_free_guidance:
# # Infer ControlNet only for the conditional batch.
# latent_control_input = latents
# latent_control_input = self.scheduler.scale_model_input(control_model_input, t)
# controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
# else:
# control_model_input = unet_latent_input
# controlnet_prompt_embeds = prompt_embeds
# controlnet(s) inference
down_samples, mid_sample = control_datum.model( down_samples, mid_sample = control_datum.model(
sample=latent_control_input, sample=control_latent_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, # encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]), # conditioning_data.text_embeddings]),
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
# cross_attention_kwargs, # cross_attention_kwargs,
guess_mode=False, guess_mode=guess_mode,
return_dict=False, return_dict=False,
) )
print("finished ControlNetModel() call, step", step_index)
if guess_mode:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None: if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else: else:
@ -708,13 +761,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
] ]
mid_block_res_sample += mid_sample mid_block_res_sample += mid_sample
# guess mode handling from diffusers controlnet pipeline:
# if guess_mode and do_classifier_free_guidance:
# # Inferred ControlNet only for the conditional batch.
# # To apply the output of ControlNet to both the unconditional and conditional batches,
# # add 0 to the unconditional batch to keep it unchanged.
# down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
# mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, x=unet_latent_input,
t, sigma=t,
conditioning_data.unconditioned_embeddings, unconditioning=conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings, conditioning=conditioning_data.text_embeddings,
conditioning_data.guidance_scale, unconditional_guidance_scale=conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s) down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
@ -1038,6 +1099,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda", device="cuda",
dtype=torch.float16, dtype=torch.float16,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
guess_mode=False,
): ):
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
@ -1068,6 +1130,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0) image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance: if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2) image = torch.cat([image] * 2)
return image return image