From fd715026a73db6f81d3f86dd6a557155ef60fdaf Mon Sep 17 00:00:00 2001 From: user1 Date: Sun, 11 Jun 2023 02:00:39 -0700 Subject: [PATCH] First pass at ControlNet "guess mode" implementation. --- .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/latent.py | 4 +- .../stable_diffusion/diffusers_pipeline.py | 108 ++++++++++++++---- 3 files changed, 93 insertions(+), 25 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index b32afe4941..dc172d9270 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -1,7 +1,7 @@ # InvokeAI nodes for ControlNet image preprocessors # initial implementation by Gregg Helt, 2023 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux -from builtins import float +from builtins import float, bool import numpy as np from typing import Literal, Optional, Union, List @@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [ ] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] +# CONTROLNET_MODE_VALUES = Literal[tuple(["BALANCED", "PROMPT", "CONTROL"])] class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") @@ -104,6 +105,7 @@ class ControlField(BaseModel): description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") + guess_mode: bool = Field(default=False, description="Toggle for guess mode") @validator("control_weight") def abs_le_one(cls, v): """validate that all abs(values) are <=1""" @@ -149,6 +151,7 @@ class ControlNetInvocation(BaseInvocation): description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") + guess_mode: bool = Field(default=False, description="Toggle for guess mode") # fmt: on class Config(InvocationConfig): @@ -174,6 +177,7 @@ class ControlNetInvocation(BaseInvocation): control_weight=self.control_weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, + guess_mode=self.guess_mode, ), ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index dbd419b6e5..7b7cced33f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -337,12 +337,14 @@ class TextToLatentsInvocation(BaseInvocation): # num_images_per_prompt=num_images_per_prompt, device=control_model.device, dtype=control_model.dtype, + guess_mode=control_info.guess_mode, ) control_item = ControlNetData(model=control_model, image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent) + end_step_percent=control_info.end_step_percent, + guess_mode=control_info.guess_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 6a11891979..52880b5e3f 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -217,10 +217,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): @dataclass class ControlNetData: model: ControlNetModel = Field(default=None) - image_tensor: torch.Tensor= Field(default=None) - weight: Union[float, List[float]]= Field(default=1.0) + image_tensor: torch.Tensor = Field(default=None) + weight: Union[float, List[float]] = Field(default=1.0) begin_step_percent: float = Field(default=0.0) end_step_percent: float = Field(default=1.0) + # FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL + guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt @dataclass(frozen=True) class ConditioningData: @@ -656,21 +658,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # TODO: should this scaling happen here or inside self._unet_forward? # i.e. before or after passing it to InvokeAIDiffuserComponent - latent_model_input = self.scheduler.scale_model_input(latents, timestep) + unet_latent_input = self.scheduler.scale_model_input(latents, timestep) + + # # guess mode handling from diffusers + # if guess_mode and do_classifier_free_guidance: + # # Infer ControlNet only for the conditional batch. + # control_model_input = latents + # control_model_input = self.scheduler.scale_model_input(control_model_input, t) + # controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + # else: + # control_model_input = unet_latent_input + # controlnet_prompt_embeds = prompt_embeds # default is no controlnet, so set controlnet processing output to None down_block_res_samples, mid_block_res_sample = None, None if control_data is not None: - # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting + # FIXME: make sure guidance_scale <= 1.0 is handled correctly if doing per-step guidance setting + # UPDATE: I think this is fixed now with pydantic validator for cfg_scale? + # So we should _never_ have guidance_scale <= 1.0 # if conditioning_data.guidance_scale > 1.0: - if conditioning_data.guidance_scale is not None: - # expand the latents input to control model if doing classifier free guidance - # (which I think for now is always true, there is conditional elsewhere that stops execution if - # classifier_free_guidance is <= 1.0 ?) - latent_control_input = torch.cat([latent_model_input] * 2) - else: - latent_control_input = latent_model_input + # if conditioning_data.guidance_scale is not None: + # if guess_mode is False: + # # expand the latents input to control model if doing classifier free guidance + # # (which I think for now is always true, there is conditional elsewhere that stops execution if + # # classifier_free_guidance is <= 1.0 ?) + # control_latent_input = torch.cat([unet_latent_input] * 2) + # else: + # control_latent_input = unet_latent_input # control_data should be type List[ControlNetData] # this loop covers both ControlNet (one ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list) @@ -680,24 +695,62 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) # only apply controlnet if current step is within the controlnet's begin/end step range if step_index >= first_control_step and step_index <= last_control_step: - # print("running controlnet", i, "for step", step_index) + guess_mode = control_datum.guess_mode + if guess_mode: + control_latent_input = unet_latent_input + else: + # expand the latents input to control model if doing classifier free guidance + # (which I think for now is always true, there is conditional elsewhere that stops execution if + # classifier_free_guidance is <= 1.0 ?) + control_latent_input = torch.cat([unet_latent_input] * 2) + + print("running controlnet", i, "for step", step_index) + print("guess mode: ", guess_mode) + print("guess mode type: ", type(guess_mode)) + if guess_mode: # only using prompt conditioning in unconditioned + encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings]) + else: + encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings, + conditioning_data.text_embeddings]) + print("encoder_hidden_states.shape", encoder_hidden_states.shape) if isinstance(control_datum.weight, list): # if controlnet has multiple weights, use the weight for the current step controlnet_weight = control_datum.weight[step_index] else: # if controlnet has a single weight, use it for all steps controlnet_weight = control_datum.weight + + # guess mode handling from diffusers controlnet pipeline: + # if guess_mode and do_classifier_free_guidance: + # # Infer ControlNet only for the conditional batch. + # latent_control_input = latents + # latent_control_input = self.scheduler.scale_model_input(control_model_input, t) + # controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + # else: + # control_model_input = unet_latent_input + # controlnet_prompt_embeds = prompt_embeds + + # controlnet(s) inference down_samples, mid_sample = control_datum.model( - sample=latent_control_input, + sample=control_latent_input, timestep=timestep, - encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, - conditioning_data.text_embeddings]), + # encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, + # conditioning_data.text_embeddings]), + encoder_hidden_states=encoder_hidden_states, controlnet_cond=control_datum.image_tensor, - conditioning_scale=controlnet_weight, + conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale # cross_attention_kwargs, - guess_mode=False, + guess_mode=guess_mode, return_dict=False, ) + print("finished ControlNetModel() call, step", step_index) + if guess_mode: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] + mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) + if down_block_res_samples is None and mid_block_res_sample is None: down_block_res_samples, mid_block_res_sample = down_samples, mid_sample else: @@ -708,13 +761,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ] mid_block_res_sample += mid_sample + # guess mode handling from diffusers controlnet pipeline: + # if guess_mode and do_classifier_free_guidance: + # # Inferred ControlNet only for the conditional batch. + # # To apply the output of ControlNet to both the unconditional and conditional batches, + # # add 0 to the unconditional batch to keep it unchanged. + # down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + # mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + # predict the noise residual noise_pred = self.invokeai_diffuser.do_diffusion_step( - latent_model_input, - t, - conditioning_data.unconditioned_embeddings, - conditioning_data.text_embeddings, - conditioning_data.guidance_scale, + x=unet_latent_input, + sigma=t, + unconditioning=conditioning_data.unconditioned_embeddings, + conditioning=conditioning_data.text_embeddings, + unconditional_guidance_scale=conditioning_data.guidance_scale, step_index=step_index, total_step_count=total_step_count, down_block_additional_residuals=down_block_res_samples, # from controlnet(s) @@ -1038,6 +1099,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): device="cuda", dtype=torch.float16, do_classifier_free_guidance=True, + guess_mode=False, ): if not isinstance(image, torch.Tensor): @@ -1068,6 +1130,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) - if do_classifier_free_guidance: + if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image