Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-27 13:18:28 +03:00
parent 5810cee6c9
commit ed0174fbc6
2 changed files with 5 additions and 8 deletions

View File

@ -94,13 +94,13 @@ class InpaintExt(ExtensionBase):
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
# TODO: order value
# Use negative order to make extensions with default order work with patched latents
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
# TODO: order value
# TODO: redo this with preview events rewrite
# Use negative order to make extensions with default order work with patched latents
@callback(ExtensionCallbackType.POST_STEP, order=-100)
def apply_mask_to_step_output(self, ctx: DenoiseContext):
timestep = ctx.scheduler.timesteps[-1]
@ -111,8 +111,7 @@ class InpaintExt(ExtensionBase):
else:
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
# TODO: should here be used order?
# restore unmasked part after the last step is completed
# Restore unmasked part after the last step is completed
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext):
if self._is_gradient_mask:

View File

@ -68,8 +68,7 @@ class InpaintModelExt(ExtensionBase):
self._masked_latents = torch.zeros_like(ctx.latents[:1])
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
# TODO: any ideas about order value?
# do last so that other extensions works with normal latents
# Use negative order to make extensions with default order work with patched latents
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
def append_inpaint_layers(self, ctx: DenoiseContext):
batch_size = ctx.unet_kwargs.sample.shape[0]
@ -80,8 +79,7 @@ class InpaintModelExt(ExtensionBase):
dim=1,
)
# TODO: should here be used order?
# restore unmasked part as inpaint model can change unmasked part slightly
# Restore unmasked part as inpaint model can change unmasked part slightly
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext):
if self._is_gradient_mask: