diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py index e74d183c2c..0506a7f1a3 100644 --- a/invokeai/backend/stable_diffusion/extensions/controlnet.py +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -31,25 +31,25 @@ class ControlNetExt(ExtensionBase): resize_mode: str, ): super().__init__() - self.model = model - self.image = image - self.weight = weight - self.begin_step_percent = begin_step_percent - self.end_step_percent = end_step_percent - self.control_mode = control_mode - self.resize_mode = resize_mode + self._model = model + self._image = image + self._weight = weight + self._begin_step_percent = begin_step_percent + self._end_step_percent = end_step_percent + self._control_mode = control_mode + self._resize_mode = resize_mode - self.image_tensor: Optional[torch.Tensor] = None + self._image_tensor: Optional[torch.Tensor] = None @contextmanager def patch_extension(self, ctx: DenoiseContext): try: - original_processors = self.model.attn_processors - self.model.set_attn_processor(ctx.inputs.attention_processor_cls()) + original_processors = self._model.attn_processors + self._model.set_attn_processor(ctx.inputs.attention_processor_cls()) yield None finally: - self.model.set_attn_processor(original_processors) + self._model.set_attn_processor(original_processors) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def resize_image(self, ctx: DenoiseContext): @@ -57,8 +57,8 @@ class ControlNetExt(ExtensionBase): image_height = latent_height * LATENT_SCALE_FACTOR image_width = latent_width * LATENT_SCALE_FACTOR - self.image_tensor = prepare_control_image( - image=self.image, + self._image_tensor = prepare_control_image( + image=self._image, do_classifier_free_guidance=False, width=image_width, height=image_height, @@ -66,22 +66,22 @@ class ControlNetExt(ExtensionBase): # num_images_per_prompt=num_images_per_prompt, device=ctx.latents.device, dtype=ctx.latents.dtype, - control_mode=self.control_mode, - resize_mode=self.resize_mode, + control_mode=self._control_mode, + resize_mode=self._resize_mode, ) @callback(ExtensionCallbackType.PRE_UNET) def pre_unet_step(self, ctx: DenoiseContext): # skip if model not active in current step total_steps = len(ctx.inputs.timesteps) - first_step = math.floor(self.begin_step_percent * total_steps) - last_step = math.ceil(self.end_step_percent * total_steps) + first_step = math.floor(self._begin_step_percent * total_steps) + last_step = math.ceil(self._end_step_percent * total_steps) if ctx.step_index < first_step or ctx.step_index > last_step: return # convert mode to internal flags - soft_injection = self.control_mode in ["more_prompt", "more_control"] - cfg_injection = self.control_mode in ["more_control", "unbalanced"] + soft_injection = self._control_mode in ["more_prompt", "more_control"] + cfg_injection = self._control_mode in ["more_control", "unbalanced"] # no negative conditioning in cfg_injection mode if cfg_injection: @@ -117,7 +117,7 @@ class ControlNetExt(ExtensionBase): total_steps = len(ctx.inputs.timesteps) model_input = ctx.latent_model_input - image_tensor = self.image_tensor + image_tensor = self._image_tensor if conditioning_mode == ConditioningMode.Both: model_input = torch.cat([model_input] * 2) image_tensor = torch.cat([image_tensor] * 2) @@ -134,7 +134,7 @@ class ControlNetExt(ExtensionBase): ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) # get static weight, or weight corresponding to current step - weight = self.weight + weight = self._weight if isinstance(weight, list): weight = weight[ctx.step_index] @@ -144,7 +144,7 @@ class ControlNetExt(ExtensionBase): tmp_kwargs.pop("down_intrablock_additional_residuals", None) # controlnet(s) inference - down_samples, mid_sample = self.model( + down_samples, mid_sample = self._model( controlnet_cond=image_tensor, conditioning_scale=weight, # controlnet specific, NOT the guidance scale guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel