mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename as suggested in other PRs
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
42356ec866
commit
3cb13d6288
@ -31,25 +31,25 @@ class ControlNetExt(ExtensionBase):
|
|||||||
resize_mode: str,
|
resize_mode: str,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self._model = model
|
||||||
self.image = image
|
self._image = image
|
||||||
self.weight = weight
|
self._weight = weight
|
||||||
self.begin_step_percent = begin_step_percent
|
self._begin_step_percent = begin_step_percent
|
||||||
self.end_step_percent = end_step_percent
|
self._end_step_percent = end_step_percent
|
||||||
self.control_mode = control_mode
|
self._control_mode = control_mode
|
||||||
self.resize_mode = resize_mode
|
self._resize_mode = resize_mode
|
||||||
|
|
||||||
self.image_tensor: Optional[torch.Tensor] = None
|
self._image_tensor: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_extension(self, ctx: DenoiseContext):
|
def patch_extension(self, ctx: DenoiseContext):
|
||||||
try:
|
try:
|
||||||
original_processors = self.model.attn_processors
|
original_processors = self._model.attn_processors
|
||||||
self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
self.model.set_attn_processor(original_processors)
|
self._model.set_attn_processor(original_processors)
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
def resize_image(self, ctx: DenoiseContext):
|
def resize_image(self, ctx: DenoiseContext):
|
||||||
@ -57,8 +57,8 @@ class ControlNetExt(ExtensionBase):
|
|||||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
self.image_tensor = prepare_control_image(
|
self._image_tensor = prepare_control_image(
|
||||||
image=self.image,
|
image=self._image,
|
||||||
do_classifier_free_guidance=False,
|
do_classifier_free_guidance=False,
|
||||||
width=image_width,
|
width=image_width,
|
||||||
height=image_height,
|
height=image_height,
|
||||||
@ -66,22 +66,22 @@ class ControlNetExt(ExtensionBase):
|
|||||||
# num_images_per_prompt=num_images_per_prompt,
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
device=ctx.latents.device,
|
device=ctx.latents.device,
|
||||||
dtype=ctx.latents.dtype,
|
dtype=ctx.latents.dtype,
|
||||||
control_mode=self.control_mode,
|
control_mode=self._control_mode,
|
||||||
resize_mode=self.resize_mode,
|
resize_mode=self._resize_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_UNET)
|
@callback(ExtensionCallbackType.PRE_UNET)
|
||||||
def pre_unet_step(self, ctx: DenoiseContext):
|
def pre_unet_step(self, ctx: DenoiseContext):
|
||||||
# skip if model not active in current step
|
# skip if model not active in current step
|
||||||
total_steps = len(ctx.inputs.timesteps)
|
total_steps = len(ctx.inputs.timesteps)
|
||||||
first_step = math.floor(self.begin_step_percent * total_steps)
|
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||||
last_step = math.ceil(self.end_step_percent * total_steps)
|
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||||
return
|
return
|
||||||
|
|
||||||
# convert mode to internal flags
|
# convert mode to internal flags
|
||||||
soft_injection = self.control_mode in ["more_prompt", "more_control"]
|
soft_injection = self._control_mode in ["more_prompt", "more_control"]
|
||||||
cfg_injection = self.control_mode in ["more_control", "unbalanced"]
|
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
|
||||||
|
|
||||||
# no negative conditioning in cfg_injection mode
|
# no negative conditioning in cfg_injection mode
|
||||||
if cfg_injection:
|
if cfg_injection:
|
||||||
@ -117,7 +117,7 @@ class ControlNetExt(ExtensionBase):
|
|||||||
total_steps = len(ctx.inputs.timesteps)
|
total_steps = len(ctx.inputs.timesteps)
|
||||||
|
|
||||||
model_input = ctx.latent_model_input
|
model_input = ctx.latent_model_input
|
||||||
image_tensor = self.image_tensor
|
image_tensor = self._image_tensor
|
||||||
if conditioning_mode == ConditioningMode.Both:
|
if conditioning_mode == ConditioningMode.Both:
|
||||||
model_input = torch.cat([model_input] * 2)
|
model_input = torch.cat([model_input] * 2)
|
||||||
image_tensor = torch.cat([image_tensor] * 2)
|
image_tensor = torch.cat([image_tensor] * 2)
|
||||||
@ -134,7 +134,7 @@ class ControlNetExt(ExtensionBase):
|
|||||||
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
||||||
|
|
||||||
# get static weight, or weight corresponding to current step
|
# get static weight, or weight corresponding to current step
|
||||||
weight = self.weight
|
weight = self._weight
|
||||||
if isinstance(weight, list):
|
if isinstance(weight, list):
|
||||||
weight = weight[ctx.step_index]
|
weight = weight[ctx.step_index]
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ class ControlNetExt(ExtensionBase):
|
|||||||
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
|
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
|
||||||
|
|
||||||
# controlnet(s) inference
|
# controlnet(s) inference
|
||||||
down_samples, mid_sample = self.model(
|
down_samples, mid_sample = self._model(
|
||||||
controlnet_cond=image_tensor,
|
controlnet_cond=image_tensor,
|
||||||
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
|
Loading…
Reference in New Issue
Block a user