From 608cbe3f5c4f5efe2ed507dfc3a81d57eaaa0423 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Tue, 16 Jul 2024 19:30:29 +0300
Subject: [PATCH] Separate inputs in denoise context

---
 invokeai/app/invocations/denoise_latents.py | 18 +++++++-------
 .../stable_diffusion/denoise_context.py | 11 ++++++---
 .../stable_diffusion/diffusion_backend.py | 24 ++++++++++---------
 .../stable_diffusion/extensions/preview.py | 4 ++--
 4 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 7563c30223..81b92d4fa7 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
-from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     StableDiffusionGeneratorPipeline,
@@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )

         denoise_ctx = DenoiseContext(
-            latents=latents,
-            timesteps=timesteps,
-            init_timestep=init_timestep,
-            noise=noise,
-            seed=seed,
-            scheduler_step_kwargs=scheduler_step_kwargs,
-            conditioning_data=conditioning_data,
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+            ),
             unet=None,
             scheduler=scheduler,
         )
diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
index 453398a121..2a00052fd1 100644
--- a/invokeai/backend/stable_diffusion/denoise_context.py
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -30,8 +30,8 @@ class UNetKwargs:


 @dataclass
-class DenoiseContext:
-    latents: torch.Tensor
+class DenoiseInputs:
+    orig_latents: torch.Tensor
     scheduler_step_kwargs: dict[str, Any]
     conditioning_data: TextConditioningData
     noise: Optional[torch.Tensor]
@@ -39,10 +39,15 @@ class DenoiseContext:
     timesteps: torch.Tensor
     init_timestep: torch.Tensor

+
+@dataclass
+class DenoiseContext:
+    inputs: DenoiseInputs
+
     scheduler: SchedulerMixin

     unet: Optional[UNet2DConditionModel] = None
-    orig_latents: Optional[torch.Tensor] = None
+    latents: Optional[torch.Tensor] = None
     step_index: Optional[int] = None
     timestep: Optional[torch.Tensor] = None
     unet_kwargs: Optional[UNetKwargs] = None
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index 4c08639ddf..f8cb92d1d4 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -22,25 +22,27 @@ class StableDiffusionBackend:
         self.sequential_guidance = config.sequential_guidance

     def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
-        if ctx.init_timestep.shape[0] == 0:
-            return ctx.latents
+        if ctx.inputs.init_timestep.shape[0] == 0:
+            return ctx.inputs.orig_latents

-        ctx.orig_latents = ctx.latents.clone()
+        ctx.latents = ctx.inputs.orig_latents.clone()

-        if ctx.noise is not None:
+        if ctx.inputs.noise is not None:
             batch_size = ctx.latents.shape[0]
             # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
-            ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))
+            ctx.latents = ctx.scheduler.add_noise(
+                ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
+            )

         # if no work to do, return latents
-        if ctx.timesteps.shape[0] == 0:
+        if ctx.inputs.timesteps.shape[0] == 0:
             return ctx.latents

         # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
         # ext: preview[pre_denoise_loop, priority=low]
         ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)

-        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)):  # noqa: B020
+        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)):  # noqa: B020
             # ext: inpaint (apply mask to latents on non-inpaint models)
             ext_manager.callbacks.pre_step(ctx, ext_manager)

@@ -80,7 +82,7 @@ class StableDiffusionBackend:
         ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)

         # compute the previous noisy sample x_t -> x_t-1
-        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
+        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)

         # clean up locals
         ctx.latent_model_input = None
@@ -92,7 +94,7 @@ class StableDiffusionBackend:

     @staticmethod
     def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
-        guidance_scale = ctx.conditioning_data.guidance_scale
+        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]

@@ -109,12 +111,12 @@ class StableDiffusionBackend:
             timestep=ctx.timestep,
             encoder_hidden_states=None,  # set later by conditoning
             cross_attention_kwargs=dict(  # noqa: C408
-                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
+                percent_through=ctx.step_index / len(ctx.inputs.timesteps),
             ),
         )

         ctx.conditioning_mode = conditioning_mode
-        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
+        ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

         # ext: controlnet/ip/t2i [pre_unet]
         ext_manager.callbacks.pre_unet(ctx, ext_manager)
diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py
index 73a1eef3c5..acc55e6172 100644
--- a/invokeai/backend/stable_diffusion/extensions/preview.py
+++ b/invokeai/backend/stable_diffusion/extensions/preview.py
@@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase):
             PipelineIntermediateState(
                 step=-1,
                 order=ctx.scheduler.order,
-                total_steps=len(ctx.timesteps),
+                total_steps=len(ctx.inputs.timesteps),
                 timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                 latents=ctx.latents,
             )
@@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase):
             PipelineIntermediateState(
                 step=ctx.step_index,
                 order=ctx.scheduler.order,
-                total_steps=len(ctx.timesteps),
+                total_steps=len(ctx.inputs.timesteps),
                 timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                 latents=ctx.step_output.prev_sample,
                 predicted_original=predicted_original,  # TODO: is there any reason for additional field?
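
To make the intent of the refactor easier to see outside the diff, here is a minimal, self-contained sketch of the new caller/backend split. It is not the actual InvokeAI implementation: the field types are simplified stand-ins (plain Python values instead of torch.Tensor, SchedulerMixin, TextConditioningData, and UNet2DConditionModel), and the demo values in the __main__ block are made up. The real classes are defined in invokeai/backend/stable_diffusion/denoise_context.py as shown in the patch above.

```python
# Usage sketch of the DenoiseInputs/DenoiseContext split, with simplified types.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DenoiseInputs:
    """Values supplied by the caller once, before the denoise loop starts."""

    orig_latents: Any
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: Any
    noise: Optional[Any]
    seed: int
    timesteps: Any
    init_timestep: Any


@dataclass
class DenoiseContext:
    """Mutable state threaded through the loop; `inputs` holds the caller-provided data."""

    inputs: DenoiseInputs
    scheduler: Any
    unet: Optional[Any] = None
    latents: Optional[Any] = None  # working copy, cloned from inputs.orig_latents
    step_index: Optional[int] = None
    timestep: Optional[Any] = None


if __name__ == "__main__":
    # Caller side (cf. DenoiseLatentsInvocation): everything immutable goes into `inputs`.
    ctx = DenoiseContext(
        inputs=DenoiseInputs(
            orig_latents=[0.0, 0.0, 0.0],
            scheduler_step_kwargs={},
            conditioning_data=None,
            noise=None,
            seed=123,
            timesteps=[999, 666, 333],
            init_timestep=[999],
        ),
        scheduler=None,
    )

    # Backend side (cf. StableDiffusionBackend.latents_from_embeddings): read-only values
    # come from ctx.inputs, while per-step working state lives directly on the context.
    ctx.latents = list(ctx.inputs.orig_latents)  # the real code uses .clone()
    for ctx.step_index, ctx.timestep in enumerate(ctx.inputs.timesteps):
        pass  # a real denoise step would run here

    print(ctx.step_index, ctx.timestep, ctx.latents)
```

The design choice the patch makes is visible in the sketch: fields that were previously overwritten in place (latents vs. orig_latents) are now separated, so the immutable caller inputs can no longer be clobbered by the per-step loop state.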