Switched ControlNet over to a single control_mode field with 4 options: balanced, more_prompt, more_control, even_more_control. Each mode corresponds to a True/False combination of the internal booleans cfg_injection and soft_injection.

user1 2023-06-13 21:08:34 -07:00
parent 8495764d45
commit de3e6cdb02
3 changed files with 27 additions and 88 deletions
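For reference, the four control_mode options reduce to the two retired booleans as follows. This is a standalone sketch distilled from the two boolean expressions added in the pipeline diff below; the helper name mode_to_flags is illustrative, not part of the commit:

```python
# Mapping from control_mode to the two internal booleans it replaces,
# distilled from the expressions added in diffusers_pipeline.py below.
def mode_to_flags(control_mode: str) -> tuple[bool, bool]:
    # soft_injection: per-layer re-weighting of the ControlNet residuals
    soft_injection = control_mode in ("more_prompt", "more_control")
    # cfg_injection: run ControlNet on the conditional batch only
    cfg_injection = control_mode in ("more_control", "even_more_control")
    return soft_injection, cfg_injection

for mode in ("balanced", "more_prompt", "more_control", "even_more_control"):
    print(mode, mode_to_flags(mode))
# balanced (False, False)
# more_prompt (True, False)
# more_control (True, True)
# even_more_control (False, True)
```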

View File

@@ -94,7 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
 ]
 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
-# CONTROLNET_MODE_VALUES = Literal[tuple(["BALANCED", "PROMPT", "CONTROL"])]
+CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "even_more_control"])]
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -105,9 +105,8 @@ class ControlField(BaseModel):
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
-    cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
-    soft_injection: bool = Field(default=False, description="Toggle for soft injection")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
     @validator("control_weight")
     def abs_le_one(cls, v):
         """validate that all abs(values) are <=1"""
@@ -148,14 +147,11 @@ class ControlNetInvocation(BaseInvocation):
     control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
                                                   description="control model used")
     control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
-    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
     begin_step_percent: float = Field(default=0, ge=0, le=1,
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
-    cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
-    soft_injection: bool = Field(default=False, description="Toggle for soft injection")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
     # fmt: on
     class Config(InvocationConfig):
@@ -173,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
         }
     def invoke(self, context: InvocationContext) -> ControlOutput:
         return ControlOutput(
             control=ControlField(
                 image=self.image,
@@ -181,9 +176,7 @@ class ControlNetInvocation(BaseInvocation):
                 control_weight=self.control_weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
-                # guess_mode=self.guess_mode,
-                cfg_injection=self.cfg_injection,
-                soft_injection=self.soft_injection,
+                control_mode=self.control_mode,
             ),
         )
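With the booleans gone from the node schema, the Literal-typed field lets pydantic validate the mode string for free. A minimal sketch under pydantic v1 (matching the @validator usage above); the Control class here is a hypothetical stand-in for ControlField:

```python
from typing import Literal
from pydantic import BaseModel, Field, ValidationError

CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "even_more_control"])]

class Control(BaseModel):  # hypothetical stand-in for ControlField
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")

print(Control().control_mode)                # balanced
print(Control(control_mode="more_control"))  # accepted
try:
    Control(control_mode="guess_mode")       # rejected: not one of the four modes
except ValidationError as err:
    print(err)
```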

View File

@@ -282,19 +282,14 @@ class TextToLatentsInvocation(BaseInvocation):
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
         if control_input is None:
-            # print("control input is None")
             control_list = None
         elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
             control_list = None
         elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
             control_list = [control_input]
         elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
             control_list = control_input
         else:
-            # print("input control is unrecognized:", type(self.control))
             control_list = None
         if (control_list is None):
             control_data = None
@@ -337,18 +332,14 @@ class TextToLatentsInvocation(BaseInvocation):
                 # num_images_per_prompt=num_images_per_prompt,
                 device=control_model.device,
                 dtype=control_model.dtype,
-                # guess_mode=control_info.guess_mode,
-                cfg_injection=control_info.cfg_injection,
-                soft_injection=control_info.soft_injection,
+                control_mode=control_info.control_mode,
             )
             control_item = ControlNetData(model=control_model,
                                           image_tensor=control_image,
                                           weight=control_info.control_weight,
                                           begin_step_percent=control_info.begin_step_percent,
                                           end_step_percent=control_info.end_step_percent,
-                                          # guess_mode=control_info.guess_mode,
-                                          cfg_injection=control_info.cfg_injection,
-                                          soft_injection=control_info.soft_injection,
+                                          control_mode=control_info.control_mode,
                                           )
             control_data.append(control_item)
         # MultiControlNetModel has been refactored out, just need list[ControlNetData]

View File

@@ -221,10 +221,8 @@ class ControlNetData:
     weight: Union[float, List[float]] = Field(default=1.0)
     begin_step_percent: float = Field(default=0.0)
     end_step_percent: float = Field(default=1.0)
-    # FIXME: replace guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
-    guess_mode: bool = Field(default=False)  # guess_mode can work with or without prompt
-    cfg_injection: bool = Field(default=False)
-    soft_injection: bool = Field(default=False)
+    control_mode: str = Field(default="balanced")
 @dataclass(frozen=True)
 class ConditioningData:
@@ -662,44 +660,30 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
-        # # guess mode handling from diffusers
-        # if guess_mode and do_classifier_free_guidance:
-        #     # Infer ControlNet only for the conditional batch.
-        #     control_model_input = latents
-        #     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
-        #     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-        # else:
-        #     control_model_input = unet_latent_input
-        #     controlnet_prompt_embeds = prompt_embeds
         # default is no controlnet, so set controlnet processing output to None
         down_block_res_samples, mid_block_res_sample = None, None
         if control_data is not None:
-            # FIXME: make sure guidance_scale <= 1.0 is handled correctly if doing per-step guidance setting
-            #        UPDATE: I think this is fixed now with pydantic validator for cfg_scale?
-            #        So we should _never_ have guidance_scale <= 1.0
-            # if conditioning_data.guidance_scale > 1.0:
-            # if conditioning_data.guidance_scale is not None:
-            #     if guess_mode is False:
-            #         # expand the latents input to control model if doing classifier free guidance
-            #         # (which I think for now is always true, there is conditional elsewhere that stops execution if
-            #         #  classifier_free_guidance is <= 1.0 ?)
-            #         control_latent_input = torch.cat([unet_latent_input] * 2)
-            #     else:
-            #         control_latent_input = unet_latent_input
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
             for i, control_datum in enumerate(control_data):
-                # print("controlnet", i, "==>", type(control_datum))
+                control_mode = control_datum.control_mode
+                # soft_injection and cfg_injection are the two internal ControlNet booleans
+                # that are combined at a higher level into the control_mode enum.
+                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
+                # or default weighting (if False)
+                soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+                # cfg_injection determines whether to apply ControlNet to only the conditional (if True)
+                # or to both the conditional and unconditional (if False)
+                cfg_injection = (control_mode == "more_control" or control_mode == "even_more_control")
                 first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
                 last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
                 # only apply controlnet if current step is within the controlnet's begin/end step range
                 if step_index >= first_control_step and step_index <= last_control_step:
-                    # guess_mode = control_datum.guess_mode
-                    guess_mode = control_datum.cfg_injection
-                    if guess_mode:
+                    if cfg_injection:
                         control_latent_input = unet_latent_input
                     else:
                         # expand the latents input to control model if doing classifier free guidance
@@ -707,15 +691,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         #  classifier_free_guidance is <= 1.0 ?)
                         control_latent_input = torch.cat([unet_latent_input] * 2)
-                    print("running controlnet", i, "for step", step_index)
-                    print("guess mode: ", guess_mode)
-                    print("guess mode type: ", type(guess_mode))
-                    if guess_mode:  # only using prompt conditioning in unconditioned
+                    if cfg_injection:  # only applying ControlNet to the conditional batch, not the unconditional
                         encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
                     else:
                         encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
                                                            conditioning_data.text_embeddings])
-                    print("encoder_hidden_states.shape", encoder_hidden_states.shape)
                     if isinstance(control_datum.weight, list):
                         # if controlnet has multiple weights, use the weight for the current step
                         controlnet_weight = control_datum.weight[step_index]
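To make the batching decision in the two hunks above concrete: under classifier-free guidance the UNet always sees a 2x batch (unconditional + conditional), while cfg_injection restricts the ControlNet input to the conditional half. A toy sketch with illustrative latent shapes:

```python
import torch

unet_latent_input = torch.randn(1, 4, 64, 64)  # illustrative SD latents

for control_mode in ("balanced", "more_control"):
    cfg_injection = control_mode in ("more_control", "even_more_control")
    if cfg_injection:
        control_latent_input = unet_latent_input                   # conditional batch only
    else:
        control_latent_input = torch.cat([unet_latent_input] * 2)  # uncond + cond
    print(control_mode, control_latent_input.shape[0])
# balanced 2
# more_control 1
```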
@@ -723,32 +703,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         # if controlnet has a single weight, use it for all steps
                         controlnet_weight = control_datum.weight
-                    # guess mode handling from diffusers controlnet pipeline:
-                    # if guess_mode and do_classifier_free_guidance:
-                    #     # Infer ControlNet only for the conditional batch.
-                    #     latent_control_input = latents
-                    #     latent_control_input = self.scheduler.scale_model_input(control_model_input, t)
-                    #     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-                    # else:
-                    #     control_model_input = unet_latent_input
-                    #     controlnet_prompt_embeds = prompt_embeds
                     # controlnet(s) inference
                     down_samples, mid_sample = control_datum.model(
                         sample=control_latent_input,
                         timestep=timestep,
-                        # encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
-                        #                                  conditioning_data.text_embeddings]),
                         encoder_hidden_states=encoder_hidden_states,
                         controlnet_cond=control_datum.image_tensor,
                         conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
-                        # cross_attention_kwargs,
-                        # guess_mode=guess_mode,
-                        guess_mode=control_datum.soft_injection,
+                        guess_mode=soft_injection,  # this is still called guess_mode in the diffusers ControlNetModel
                         return_dict=False,
                     )
-                    print("finished ControlNetModel() call, step", step_index)
-                    if guess_mode:
+                    if cfg_injection:
                         # Inferred ControlNet only for the conditional batch.
                         # To apply the output of ControlNet to both the unconditional and conditional batches,
                         # add 0 to the unconditional batch to keep it unchanged.
@@ -765,14 +730,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         ]
                         mid_block_res_sample += mid_sample
-            # guess mode handling from diffusers controlnet pipeline:
-            # if guess_mode and do_classifier_free_guidance:
-            #     # Inferred ControlNet only for the conditional batch.
-            #     # To apply the output of ControlNet to both the unconditional and conditional batches,
-            #     # add 0 to the unconditional batch to keep it unchanged.
-            #     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-            #     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             x=unet_latent_input,
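The zero-padding described in the comments above (and in the removed diffusers-derived comment block) is what keeps the residual shapes compatible with the 2x CFG batch when ControlNet only ran on the conditional half: zeros are concatenated in front, so the unconditional pass is unaffected. A toy sketch with illustrative shapes:

```python
import torch

# cond-only ControlNet residuals (batch of 1; channel/spatial sizes illustrative)
down_samples = [torch.randn(1, 320, 64, 64), torch.randn(1, 640, 32, 32)]
mid_sample = torch.randn(1, 1280, 8, 8)

# pad with zeros for the unconditional half of the CFG batch
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

print([d.shape[0] for d in down_block_res_samples], mid_block_res_sample.shape[0])  # [2, 2] 2
```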
@@ -1103,9 +1060,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         device="cuda",
         dtype=torch.float16,
         do_classifier_free_guidance=True,
-        guess_mode=False,
-        soft_injection=False,
-        cfg_injection=False,
+        control_mode="balanced",
     ):
         if not isinstance(image, torch.Tensor):
@@ -1136,7 +1091,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             repeat_by = num_images_per_prompt
             image = image.repeat_interleave(repeat_by, dim=0)
             image = image.to(device=device, dtype=dtype)
-        # if do_classifier_free_guidance and not guess_mode:
+        cfg_injection = (control_mode == "more_control" or control_mode == "even_more_control")
         if do_classifier_free_guidance and not cfg_injection:
             image = torch.cat([image] * 2)
         return image
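Pulling the last hunk together: the control-image preparation now derives cfg_injection from control_mode itself and skips the CFG duplication when the ControlNet will only ever see the conditional batch. A hypothetical helper distilling the tail end of that method:

```python
import torch

def finish_control_image(image: torch.Tensor,
                         do_classifier_free_guidance: bool = True,
                         control_mode: str = "balanced") -> torch.Tensor:
    # hypothetical distillation of the end of the image-prep method above
    cfg_injection = control_mode in ("more_control", "even_more_control")
    if do_classifier_free_guidance and not cfg_injection:
        image = torch.cat([image] * 2)  # duplicate for uncond + cond
    return image

img = torch.randn(1, 3, 512, 512)
print(finish_control_image(img, control_mode="balanced").shape[0])      # 2
print(finish_control_image(img, control_mode="more_control").shape[0])  # 1
```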