diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4630d4740d..561624609b 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -61,18 +61,19 @@ class StableDiffusionBackend: def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) + # TODO: conditionings as list if self.sequential_guidance: - conditioning_call = self._apply_standard_conditioning_sequentially + ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") + ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") else: - conditioning_call = self._apply_standard_conditioning - - # not sure if here needed override - ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) + both_noise_pred = self.run_unet(ctx, ext_manager, "both") + ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) # ext: override apply_cfg ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) # ext: cfg_rescale [modify_noise_prediction] + # TODO: rename ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 @@ -95,15 +96,13 @@ class StableDiffusionBackend: return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) - def _apply_standard_conditioning( - self, ctx: DenoiseContext, ext_manager: ExtensionsManager - ) -> tuple[torch.Tensor, torch.Tensor]: - """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at - the cost of higher memory usage. - """ + def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str): + sample = ctx.latent_model_input + if conditioning_mode == "both": + sample = torch.cat([sample] * 2) ctx.unet_kwargs = UNetKwargs( - sample=torch.cat([ctx.latent_model_input] * 2), + sample=sample, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning cross_attention_kwargs=dict( # noqa: C408 @@ -111,7 +110,7 @@ class StableDiffusionBackend: ), ) - ctx.conditioning_mode = "both" + ctx.conditioning_mode = conditioning_mode ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet_forward] @@ -120,75 +119,12 @@ class StableDiffusionBackend: # ext: inpaint [pre_unet_forward, priority=low] # or # ext: inpaint [override: unet_forward] - both_results = self._unet_forward(**vars(ctx.unet_kwargs)) - negative_next_x, positive_next_x = both_results.chunk(2) - # del locals - del ctx.unet_kwargs - del ctx.conditioning_mode - return negative_next_x, positive_next_x - - def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of - slower execution speed. - """ - - ################### - # Negative pass - ################### - - ctx.unet_kwargs = UNetKwargs( - sample=ctx.latent_model_input, - timestep=ctx.timestep, - encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, - ), - ) - - ctx.conditioning_mode = "negative" - ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") - - # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) - - # ext: inpaint [pre_unet_forward, priority=low] - # or - # ext: inpaint [override: unet_forward] - negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + noise_pred = self._unet_forward(**vars(ctx.unet_kwargs)) del ctx.unet_kwargs del ctx.conditioning_mode - # TODO: gc.collect() ? - ################### - # Positive pass - ################### - - ctx.unet_kwargs = UNetKwargs( - sample=ctx.latent_model_input, - timestep=ctx.timestep, - encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, - ), - ) - - ctx.conditioning_mode = "positive" - ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") - - # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) - - # ext: inpaint [pre_unet_forward, priority=low] - # or - # ext: inpaint [override: unet_forward] - positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) - - del ctx.unet_kwargs - del ctx.conditioning_mode - # TODO: gc.collect() ? - - return negative_next_x, positive_next_x + return noise_pred def _unet_forward(self, **kwargs) -> torch.Tensor: return self.unet(**kwargs).sample