Rename as suggested in other PRs

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-23 01:01:18 +03:00
parent 42356ec866
commit 3cb13d6288

View File

@ -31,25 +31,25 @@ class ControlNetExt(ExtensionBase):
resize_mode: str,
):
super().__init__()
self.model = model
self.image = image
self.weight = weight
self.begin_step_percent = begin_step_percent
self.end_step_percent = end_step_percent
self.control_mode = control_mode
self.resize_mode = resize_mode
self._model = model
self._image = image
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
self._control_mode = control_mode
self._resize_mode = resize_mode
self.image_tensor: Optional[torch.Tensor] = None
self._image_tensor: Optional[torch.Tensor] = None
@contextmanager
def patch_extension(self, ctx: DenoiseContext):
try:
original_processors = self.model.attn_processors
self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
original_processors = self._model.attn_processors
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
yield None
finally:
self.model.set_attn_processor(original_processors)
self._model.set_attn_processor(original_processors)
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def resize_image(self, ctx: DenoiseContext):
@ -57,8 +57,8 @@ class ControlNetExt(ExtensionBase):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
self.image_tensor = prepare_control_image(
image=self.image,
self._image_tensor = prepare_control_image(
image=self._image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
@ -66,22 +66,22 @@ class ControlNetExt(ExtensionBase):
# num_images_per_prompt=num_images_per_prompt,
device=ctx.latents.device,
dtype=ctx.latents.dtype,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
control_mode=self._control_mode,
resize_mode=self._resize_mode,
)
@callback(ExtensionCallbackType.PRE_UNET)
def pre_unet_step(self, ctx: DenoiseContext):
# skip if model not active in current step
total_steps = len(ctx.inputs.timesteps)
first_step = math.floor(self.begin_step_percent * total_steps)
last_step = math.ceil(self.end_step_percent * total_steps)
first_step = math.floor(self._begin_step_percent * total_steps)
last_step = math.ceil(self._end_step_percent * total_steps)
if ctx.step_index < first_step or ctx.step_index > last_step:
return
# convert mode to internal flags
soft_injection = self.control_mode in ["more_prompt", "more_control"]
cfg_injection = self.control_mode in ["more_control", "unbalanced"]
soft_injection = self._control_mode in ["more_prompt", "more_control"]
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
# no negative conditioning in cfg_injection mode
if cfg_injection:
@ -117,7 +117,7 @@ class ControlNetExt(ExtensionBase):
total_steps = len(ctx.inputs.timesteps)
model_input = ctx.latent_model_input
image_tensor = self.image_tensor
image_tensor = self._image_tensor
if conditioning_mode == ConditioningMode.Both:
model_input = torch.cat([model_input] * 2)
image_tensor = torch.cat([image_tensor] * 2)
@ -134,7 +134,7 @@ class ControlNetExt(ExtensionBase):
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
# get static weight, or weight corresponding to current step
weight = self.weight
weight = self._weight
if isinstance(weight, list):
weight = weight[ctx.step_index]
@ -144,7 +144,7 @@ class ControlNetExt(ExtensionBase):
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
# controlnet(s) inference
down_samples, mid_sample = self.model(
down_samples, mid_sample = self._model(
controlnet_cond=image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel