From 9e0e26f4c4eb347cf360aa8d05724037735fd0ca Mon Sep 17 00:00:00 2001 From: user1 Date: Mon, 12 Jun 2023 23:57:57 -0700 Subject: [PATCH] Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes --- .../app/invocations/controlnet_image_processors.py | 12 +++++++++--- invokeai/app/invocations/latent.py | 9 +++++++-- .../backend/stable_diffusion/diffusers_pipeline.py | 13 ++++++++++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index dc172d9270..84e18e69bd 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -105,7 +105,9 @@ class ControlField(BaseModel): description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") - guess_mode: bool = Field(default=False, description="Toggle for guess mode") + # guess_mode: bool = Field(default=False, description="Toggle for guess mode") + cfg_injection: bool = Field(default=False, description="Toggle for cfg injection") + soft_injection: bool = Field(default=False, description="Toggle for soft injection") @validator("control_weight") def abs_le_one(cls, v): """validate that all abs(values) are <=1""" @@ -151,7 +153,9 @@ class ControlNetInvocation(BaseInvocation): description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") - guess_mode: bool = Field(default=False, description="Toggle for guess mode") + # guess_mode: bool = Field(default=False, description="Toggle for guess mode") + cfg_injection: bool = Field(default=False, description="Toggle for cfg injection") + soft_injection: bool = Field(default=False, description="Toggle for soft injection") # fmt: on class Config(InvocationConfig): @@ -177,7 +181,9 @@ class ControlNetInvocation(BaseInvocation): control_weight=self.control_weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, - guess_mode=self.guess_mode, + # guess_mode=self.guess_mode, + cfg_injection=self.cfg_injection, + soft_injection=self.soft_injection, ), ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 7b7cced33f..104d1003d0 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -337,14 +337,19 @@ class TextToLatentsInvocation(BaseInvocation): # num_images_per_prompt=num_images_per_prompt, device=control_model.device, dtype=control_model.dtype, - guess_mode=control_info.guess_mode, + # guess_mode=control_info.guess_mode, + cfg_injection=control_info.cfg_injection, + soft_injection=control_info.soft_injection, ) control_item = ControlNetData(model=control_model, image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, - guess_mode=control_info.guess_mode,) + # guess_mode=control_info.guess_mode, + cfg_injection=control_info.cfg_injection, + soft_injection=control_info.soft_injection, + ) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 52880b5e3f..5b6848d8ad 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -223,6 +223,8 @@ class ControlNetData: end_step_percent: float = Field(default=1.0) # FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt + cfg_injection: bool = Field(default=False) + soft_injection: bool = Field(default=False) @dataclass(frozen=True) class ConditioningData: @@ -695,7 +697,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) # only apply controlnet if current step is within the controlnet's begin/end step range if step_index >= first_control_step and step_index <= last_control_step: - guess_mode = control_datum.guess_mode + # guess_mode = control_datum.guess_mode + guess_mode = control_datum.cfg_injection if guess_mode: control_latent_input = unet_latent_input else: @@ -740,7 +743,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): controlnet_cond=control_datum.image_tensor, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale # cross_attention_kwargs, - guess_mode=guess_mode, + # guess_mode=guess_mode, + guess_mode=control_datum.soft_injection, return_dict=False, ) print("finished ControlNetModel() call, step", step_index) @@ -1100,6 +1104,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): dtype=torch.float16, do_classifier_free_guidance=True, guess_mode=False, + soft_injection=False, + cfg_injection=False, ): if not isinstance(image, torch.Tensor): @@ -1130,6 +1136,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) - if do_classifier_free_guidance and not guess_mode: + # if do_classifier_free_guidance and not guess_mode: + if do_classifier_free_guidance and not cfg_injection: image = torch.cat([image] * 2) return image