From de3e6cdb02a59a809746fbdde42b4378f5cbd676 Mon Sep 17 00:00:00 2001
From: user1
Date: Tue, 13 Jun 2023 21:08:34 -0700
Subject: [PATCH] Switched over to ControlNet control_mode with 4 options: balanced, more_prompt, more_control, even_more_control. Based on True/False combinations of internal booleans cfg_injection and soft_injection

---
 .../controlnet_image_processors.py           | 17 ++--
 invokeai/app/invocations/latent.py           | 13 +--
 .../stable_diffusion/diffusers_pipeline.py   | 85 +++++--------------
 3 files changed, 27 insertions(+), 88 deletions(-)

diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 84e18e69bd..c433e90648 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -94,7 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
 ]
 
 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
-# CONTROLNET_MODE_VALUES = Literal[tuple(["BALANCED", "PROMPT", "CONTROL"])]
+CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "even_more_control"])]
 
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -105,9 +105,8 @@ class ControlField(BaseModel):
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
-    cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
-    soft_injection: bool = Field(default=False, description="Toggle for soft injection")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+
     @validator("control_weight")
     def abs_le_one(cls, v):
         """validate that all abs(values) are <=1"""
@@ -148,14 +147,11 @@ class ControlNetInvocation(BaseInvocation):
     control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
                                                   description="control model used")
     control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
-    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
     begin_step_percent: float = Field(default=0, ge=0, le=1,
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
-    cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
-    soft_injection: bool = Field(default=False, description="Toggle for soft injection")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
     # fmt: on
 
     class Config(InvocationConfig):
@@ -173,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
         }
 
     def invoke(self, context: InvocationContext) -> ControlOutput:
-
         return ControlOutput(
             control=ControlField(
                 image=self.image,
@@ -181,9 +176,7 @@ class ControlNetInvocation(BaseInvocation):
                 control_weight=self.control_weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
-                # guess_mode=self.guess_mode,
-                cfg_injection=self.cfg_injection,
-                soft_injection=self.soft_injection,
+                control_mode=self.control_mode,
             ),
         )
 
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index b067118010..a712b027c0 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -282,19 +282,14 @@ class TextToLatentsInvocation(BaseInvocation):
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
         if control_input is None:
-            # print("control input is None")
             control_list = None
         elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
             control_list = None
         elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
             control_list = [control_input]
         elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
             control_list = control_input
         else:
-            # print("input control is unrecognized:", type(self.control))
             control_list = None
         if (control_list is None):
             control_data = None
@@ -337,18 +332,14 @@ class TextToLatentsInvocation(BaseInvocation):
                 # num_images_per_prompt=num_images_per_prompt,
                 device=control_model.device,
                 dtype=control_model.dtype,
-                # guess_mode=control_info.guess_mode,
-                cfg_injection=control_info.cfg_injection,
-                soft_injection=control_info.soft_injection,
+                control_mode=control_info.control_mode,
             )
             control_item = ControlNetData(model=control_model,
                                           image_tensor=control_image,
                                           weight=control_info.control_weight,
                                           begin_step_percent=control_info.begin_step_percent,
                                           end_step_percent=control_info.end_step_percent,
-                                          # guess_mode=control_info.guess_mode,
-                                          cfg_injection=control_info.cfg_injection,
-                                          soft_injection=control_info.soft_injection,
+                                          control_mode=control_info.control_mode,
                                           )
             control_data.append(control_item)
             # MultiControlNetModel has been refactored out, just need list[ControlNetData]
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 9f03788657..58ea30ef48 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -221,10 +221,8 @@ class ControlNetData:
     weight: Union[float, List[float]] = Field(default=1.0)
     begin_step_percent: float = Field(default=0.0)
     end_step_percent: float = Field(default=1.0)
-    # FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
-    guess_mode: bool = Field(default=False)  # guess_mode can work with or without prompt
-    cfg_injection: bool = Field(default=False)
-    soft_injection: bool = Field(default=False)
+    control_mode: str = Field(default="balanced")
+
 
 @dataclass(frozen=True)
 class ConditioningData:
@@ -662,44 +660,30 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
 
-        # # guess mode handling from diffusers
-        # if guess_mode and do_classifier_free_guidance:
-        #     # Infer ControlNet only for the conditional batch.
-        #     control_model_input = latents
-        #     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
-        #     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-        # else:
-        #     control_model_input = unet_latent_input
-        #     controlnet_prompt_embeds = prompt_embeds
-
         # default is no controlnet, so set controlnet processing output to None
         down_block_res_samples, mid_block_res_sample = None, None
 
         if control_data is not None:
-            # FIXME: make sure guidance_scale <= 1.0 is handled correctly if doing per-step guidance setting
-            # UPDATE: I think this is fixed now with pydantic validator for cfg_scale?
-            #         So we should _never_ have guidance_scale <= 1.0
-            # if conditioning_data.guidance_scale > 1.0:
-            # if conditioning_data.guidance_scale is not None:
-            #     if guess_mode is False:
-            #         # expand the latents input to control model if doing classifier free guidance
-            #         # (which I think for now is always true, there is conditional elsewhere that stops execution if
-            #         #     classifier_free_guidance is <= 1.0 ?)
-            #         control_latent_input = torch.cat([unet_latent_input] * 2)
-            #     else:
-            #         control_latent_input = unet_latent_input
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
            # and MultiControlNet (multiple ControlNetData in list)
             for i, control_datum in enumerate(control_data):
-                # print("controlnet", i, "==>", type(control_datum))
+                control_mode = control_datum.control_mode
+                # soft_injection and cfg_injection are the two internal booleans whose True/False
+                # combinations are exposed at the higher level as the control_mode enum
+                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
+                # or default weighting (if False)
+                soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+                # cfg_injection determines whether to apply ControlNet only to the conditional batch (if True)
+                # or to both the conditional and unconditional batches (the default, if False)
+                cfg_injection = (control_mode == "more_control" or control_mode == "even_more_control")
+
                 first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
                 last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
                 # only apply controlnet if current step is within the controlnet's begin/end step range
                 if step_index >= first_control_step and step_index <= last_control_step:
-                    # guess_mode = control_datum.guess_mode
-                    guess_mode = control_datum.cfg_injection
-                    if guess_mode:
+
+                    if cfg_injection:
                         control_latent_input = unet_latent_input
                     else:
                         # expand the latents input to control model if doing classifier free guidance
@@ -707,15 +691,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         #     classifier_free_guidance is <= 1.0 ?)
                        control_latent_input = torch.cat([unet_latent_input] * 2)
 
-                    print("running controlnet", i, "for step", step_index)
-                    print("guess mode: ", guess_mode)
-                    print("guess mode type: ", type(guess_mode))
-                    if guess_mode:  # only using prompt conditioning in unconditioned
+                    if cfg_injection:  # only apply ControlNet to the conditional batch, not the unconditional one
                         encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
                     else:
                         encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
                                                            conditioning_data.text_embeddings])
-                    print("encoder_hidden_states.shape", encoder_hidden_states.shape)
                     if isinstance(control_datum.weight, list):
                         # if controlnet has multiple weights, use the weight for the current step
                         controlnet_weight = control_datum.weight[step_index]
@@ -723,35 +703,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         # if controlnet has a single weight, use it for all steps
                         controlnet_weight = control_datum.weight
 
-                    # guess mode handling from diffusers controlnet pipeline:
-                    # if guess_mode and do_classifier_free_guidance:
-                    #     # Infer ControlNet only for the conditional batch.
-                    #     latent_control_input = latents
-                    #     latent_control_input = self.scheduler.scale_model_input(control_model_input, t)
-                    #     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-                    # else:
-                    #     control_model_input = unet_latent_input
-                    #     controlnet_prompt_embeds = prompt_embeds
-
                     # controlnet(s) inference
                     down_samples, mid_sample = control_datum.model(
                         sample=control_latent_input,
                         timestep=timestep,
-                        # encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
-                        #                                  conditioning_data.text_embeddings]),
                         encoder_hidden_states=encoder_hidden_states,
                         controlnet_cond=control_datum.image_tensor,
                         conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
-                        # cross_attention_kwargs,
-                        # guess_mode=guess_mode,
-                        guess_mode=control_datum.soft_injection,
+                        guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                         return_dict=False,
                     )
-                    print("finished ControlNetModel() call, step", step_index)
-                    if guess_mode:
+                    if cfg_injection:
                         # Inferred ControlNet only for the conditional batch.
                         # To apply the output of ControlNet to both the unconditional and conditional batches,
-                        # add 0 to the unconditional batch to keep it unchanged. 
+                        # add 0 to the unconditional batch to keep it unchanged.
                         down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                         mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
 
@@ -765,14 +730,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         ]
                         mid_block_res_sample += mid_sample
 
-                    # guess mode handling from diffusers controlnet pipeline:
-                    # if guess_mode and do_classifier_free_guidance:
-                    #     # Inferred ControlNet only for the conditional batch.
-                    #     # To apply the output of ControlNet to both the unconditional and conditional batches,
-                    #     # add 0 to the unconditional batch to keep it unchanged.
-                    #     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                    #     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             x=unet_latent_input,
@@ -1103,9 +1060,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         device="cuda",
         dtype=torch.float16,
         do_classifier_free_guidance=True,
-        guess_mode=False,
-        soft_injection=False,
-        cfg_injection=False,
+        control_mode="balanced"
     ):
 
         if not isinstance(image, torch.Tensor):
@@ -1136,7 +1091,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             repeat_by = num_images_per_prompt
         image = image.repeat_interleave(repeat_by, dim=0)
         image = image.to(device=device, dtype=dtype)
-        # if do_classifier_free_guidance and not guess_mode:
+        cfg_injection = (control_mode == "more_control" or control_mode == "even_more_control")
         if do_classifier_free_guidance and not cfg_injection:
             image = torch.cat([image] * 2)
         return image
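
Note (not part of the patch): the mapping between the four control_mode values and the two internal booleans described in the commit message can be summarized as below. This is a minimal sketch for reference only; the helper name resolve_control_mode is illustrative and does not exist in the InvokeAI codebase.

    # control_mode -> (soft_injection, cfg_injection), derived from the boolean
    # expressions in diffusers_pipeline.py above:
    #   soft_injection: per-layer re-weighting of ControlNet residuals
    #                   (passed to diffusers ControlNetModel as guess_mode)
    #   cfg_injection:  apply ControlNet only to the conditional half of the CFG batch
    def resolve_control_mode(control_mode: str) -> tuple[bool, bool]:
        """Return (soft_injection, cfg_injection) for a control_mode string."""
        mapping = {
            "balanced":          (False, False),
            "more_prompt":       (True,  False),
            "more_control":      (True,  True),
            "even_more_control": (False, True),
        }
        return mapping[control_mode]

    # Example: "more_control" enables both per-layer soft weighting and
    # conditional-only injection, matching soft_injection=True, cfg_injection=True.
    assert resolve_control_mode("more_control") == (True, True)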