mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename as suggested in other PRs
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
42356ec866
commit
3cb13d6288
@ -31,25 +31,25 @@ class ControlNetExt(ExtensionBase):
|
||||
resize_mode: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.image = image
|
||||
self.weight = weight
|
||||
self.begin_step_percent = begin_step_percent
|
||||
self.end_step_percent = end_step_percent
|
||||
self.control_mode = control_mode
|
||||
self.resize_mode = resize_mode
|
||||
self._model = model
|
||||
self._image = image
|
||||
self._weight = weight
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
self._control_mode = control_mode
|
||||
self._resize_mode = resize_mode
|
||||
|
||||
self.image_tensor: Optional[torch.Tensor] = None
|
||||
self._image_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
@contextmanager
|
||||
def patch_extension(self, ctx: DenoiseContext):
|
||||
try:
|
||||
original_processors = self.model.attn_processors
|
||||
self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||
original_processors = self._model.attn_processors
|
||||
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self.model.set_attn_processor(original_processors)
|
||||
self._model.set_attn_processor(original_processors)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def resize_image(self, ctx: DenoiseContext):
|
||||
@ -57,8 +57,8 @@ class ControlNetExt(ExtensionBase):
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
self.image_tensor = prepare_control_image(
|
||||
image=self.image,
|
||||
self._image_tensor = prepare_control_image(
|
||||
image=self._image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
@ -66,22 +66,22 @@ class ControlNetExt(ExtensionBase):
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=ctx.latents.device,
|
||||
dtype=ctx.latents.dtype,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
control_mode=self._control_mode,
|
||||
resize_mode=self._resize_mode,
|
||||
)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_UNET)
|
||||
def pre_unet_step(self, ctx: DenoiseContext):
|
||||
# skip if model not active in current step
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
first_step = math.floor(self.begin_step_percent * total_steps)
|
||||
last_step = math.ceil(self.end_step_percent * total_steps)
|
||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||
return
|
||||
|
||||
# convert mode to internal flags
|
||||
soft_injection = self.control_mode in ["more_prompt", "more_control"]
|
||||
cfg_injection = self.control_mode in ["more_control", "unbalanced"]
|
||||
soft_injection = self._control_mode in ["more_prompt", "more_control"]
|
||||
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
|
||||
|
||||
# no negative conditioning in cfg_injection mode
|
||||
if cfg_injection:
|
||||
@ -117,7 +117,7 @@ class ControlNetExt(ExtensionBase):
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
|
||||
model_input = ctx.latent_model_input
|
||||
image_tensor = self.image_tensor
|
||||
image_tensor = self._image_tensor
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
model_input = torch.cat([model_input] * 2)
|
||||
image_tensor = torch.cat([image_tensor] * 2)
|
||||
@ -134,7 +134,7 @@ class ControlNetExt(ExtensionBase):
|
||||
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
||||
|
||||
# get static weight, or weight corresponding to current step
|
||||
weight = self.weight
|
||||
weight = self._weight
|
||||
if isinstance(weight, list):
|
||||
weight = weight[ctx.step_index]
|
||||
|
||||
@ -144,7 +144,7 @@ class ControlNetExt(ExtensionBase):
|
||||
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
|
||||
|
||||
# controlnet(s) inference
|
||||
down_samples, mid_sample = self.model(
|
||||
down_samples, mid_sample = self._model(
|
||||
controlnet_cond=image_tensor,
|
||||
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||
|
Loading…
Reference in New Issue
Block a user