From f613c073c1650234472b6095ee22dca3560ff549 Mon Sep 17 00:00:00 2001
From: user1
Date: Fri, 12 May 2023 04:01:35 -0700
Subject: [PATCH] Added support for specifying which step iteration to start
 using each ControlNet, and which step to end using each controlnet
 (specified as fraction of total steps)

---
 .../controlnet_image_processors.py              | 25 ++++-----
 invokeai/app/invocations/latent.py              |  4 +-
 .../stable_diffusion/diffusers_pipeline.py      | 54 ++++++++++---------
 3 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 169dfc96c4..6cb9a73976 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -33,16 +33,16 @@ from .image import ImageOutput, build_image_output, PILInvocationConfig
 
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="processed image")
-    # width: Optional[int] = Field(default=None, description="The width of the image in pixels")
-    # height: Optional[int] = Field(default=None, description="The height of the image in pixels")
-    # mode: Optional[str] = Field(default=None, description="The mode of the image")
     control_model: Optional[str] = Field(default=None, description="control model used")
     control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")
 
     class Config:
         schema_extra = {
-            "required": ["image", "control_model", "control_weight"]
-            # "required": ["type", "image", "width", "height", "mode"]
+            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
         }
 
 
@@ -62,12 +62,11 @@ class ControlNetInvocation(BaseInvocation):
     image: ImageField = Field(default=None, description="image to process")
     control_model: str = Field(default=None, description="control model to use")
     control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
-    # TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
-    # begin_step_percent: float = Field(default=0, ge=0, le=1,
-    #                                   description="% of total steps at which controlnet is first applied")
-    # end_step_percent: float = Field(default=1, ge=0, le=1,
-    #                                 description="% of total steps at which controlnet is last applied")
-    # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
+    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")
     # fmt: on
 
 
@@ -77,7 +76,9 @@ class ControlNetInvocation(BaseInvocation):
             control=ControlField(
                 image=self.image,
                 control_model=self.control_model,
-                control_weight=self.control_weight
+                control_weight=self.control_weight,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
             ),
         )
 
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index b9ce5c10a8..d644874c17 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -317,7 +317,9 @@ class TextToLatentsInvocation(BaseInvocation):
                 )
                 control_item = ControlNetData(model=control_model,
                                               image_tensor=control_image,
-                                              weight=control_info.control_weight)
+                                              weight=control_info.control_weight,
+                                              begin_step_percent=control_info.begin_step_percent,
+                                              end_step_percent=control_info.end_step_percent)
                 control_data.append(control_item)
                 # multi_control = MultiControlNetModel(control_models)
                 # model.control_model = multi_control
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index a6c365967c..0e685e5a9a 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -219,7 +219,8 @@ class ControlNetData:
     model: ControlNetModel = Field(default=None)
     image_tensor: torch.Tensor= Field(default=None)
     weight: float = Field(default=1.0)
-
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)
 
 @dataclass(frozen=True)
 class ConditioningData:
@@ -657,7 +658,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)
 
-        # if (self.control_model is not None) and (control_image is not None):
+        # default is no controlnet, so set controlnet processing output to None
+        down_block_res_samples, mid_block_res_sample = None, None
+
         if control_data is not None:
             if conditioning_data.guidance_scale > 1.0:
                 # expand the latents input to control model if doing classifier free guidance
@@ -671,28 +674,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # and MultiControlNet (multiple ControlNetData in list)
             for i, control_datum in enumerate(control_data):
                 # print("controlnet", i, "==>", type(control_datum))
-                down_samples, mid_sample = control_datum.model(
-                    sample=latent_control_input,
-                    timestep=timestep,
-                    encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
-                                                     conditioning_data.text_embeddings]),
-                    controlnet_cond=control_datum.image_tensor,
-                    conditioning_scale=control_datum.weight,
-                    # cross_attention_kwargs,
-                    guess_mode=False,
-                    return_dict=False,
-                )
-                if i == 0:
-                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
-                else:
-                    # add controlnet outputs together if have multiple controlnets
-                    down_block_res_samples = [
-                        samples_prev + samples_curr
-                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
-                    ]
-                    mid_block_res_sample += mid_sample
-        else:
-            down_block_res_samples, mid_block_res_sample = None, None
+                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+                # apply_control_this_step = step_index >= first_control_step and step_index <= last_control_step
+                if step_index >= first_control_step and step_index <= last_control_step:
+                    print("running controlnet", i, "for step", step_index)
+                    down_samples, mid_sample = control_datum.model(
+                        sample=latent_control_input,
+                        timestep=timestep,
+                        encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
+                                                         conditioning_data.text_embeddings]),
+                        controlnet_cond=control_datum.image_tensor,
+                        conditioning_scale=control_datum.weight,
+                        # cross_attention_kwargs,
+                        guess_mode=False,
+                        return_dict=False,
+                    )
+                    if down_block_res_samples is None and mid_block_res_sample is None:
+                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                    else:
+                        # add controlnet outputs together if have multiple controlnets
+                        down_block_res_samples = [
+                            samples_prev + samples_curr
+                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                        ]
+                        mid_block_res_sample += mid_sample
 
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
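
A minimal sketch of how begin_step_percent and end_step_percent map to the
window of denoising steps on which a ControlNet is applied, using the same
floor/ceil arithmetic as the diffusers_pipeline.py hunk above. The helper name
control_step_window is illustrative only (not part of the patch), and the patch
itself assumes `math` is already imported in diffusers_pipeline.py; if it is
not, an `import math` would be needed there.

    import math

    def control_step_window(begin_step_percent: float, end_step_percent: float,
                            total_step_count: int) -> range:
        # Mirror the logic added above: floor the start fraction, ceil the end
        # fraction, and treat the resulting window as inclusive at both ends.
        first_control_step = math.floor(begin_step_percent * total_step_count)
        last_control_step = math.ceil(end_step_percent * total_step_count)
        return range(first_control_step, last_control_step + 1)

    # Example: with 30 denoising steps, begin=0.25 and end=0.75 keep the
    # ControlNet active on step indices 7 through 23
    # (floor(7.5) = 7, ceil(22.5) = 23).
    assert control_step_window(0.25, 0.75, 30) == range(7, 24)

    # The defaults (0 and 1) cover every step index, preserving the previous
    # behavior of applying the ControlNet on all steps.
    assert control_step_window(0.0, 1.0, 30) == range(0, 31)

Using ceil on the end fraction biases toward keeping the ControlNet active for
a fractional final step rather than dropping it.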