diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 6005bc83e0..17a79cca90 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -723,90 +723,88 @@ class DenoiseLatentsInvocation(BaseInvocation):
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
-        # TODO: remove supression when extensions which use models added
-        with ExitStack() as exit_stack:  # noqa: F841
-            ext_manager = ExtensionsManager()
+        ext_manager = ExtensionsManager()
 
-            device = TorchDevice.choose_torch_device()
-            dtype = TorchDevice.choose_torch_dtype()
+        device = TorchDevice.choose_torch_device()
+        dtype = TorchDevice.choose_torch_dtype()
 
-            seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
-            latents = latents.to(device=device, dtype=dtype)
-            if noise is not None:
-                noise = noise.to(device=device, dtype=dtype)
+        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
 
-            _, _, latent_height, latent_width = latents.shape
+        _, _, latent_height, latent_width = latents.shape
 
-            conditioning_data = self.get_conditioning_data(
-                context=context,
-                positive_conditioning_field=self.positive_conditioning,
-                negative_conditioning_field=self.negative_conditioning,
-                cfg_scale=self.cfg_scale,
-                steps=self.steps,
-                latent_height=latent_height,
-                latent_width=latent_width,
-                device=device,
-                dtype=dtype,
-                # TODO: old backend, remove
-                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
-            )
+        conditioning_data = self.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            device=device,
+            dtype=dtype,
+            # TODO: old backend, remove
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
 
-            scheduler = get_scheduler(
-                context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+            seed=seed,
+        )
+
+        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+            scheduler,
+            seed=seed,
+            device=device,
+            steps=self.steps,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+        )
+
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
                 seed=seed,
-            )
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
 
-            timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                scheduler,
-                seed=seed,
-                device=device,
-                steps=self.steps,
-                denoising_start=self.denoising_start,
-                denoising_end=self.denoising_end,
-            )
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
 
-            denoise_ctx = DenoiseContext(
-                inputs=DenoiseInputs(
-                    orig_latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    attention_processor_cls=CustomAttnProcessor2_0,
-                ),
-                unet=None,
-                scheduler=scheduler,
-            )
+        ### preview
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
 
-            ### preview
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
+        ext_manager.add_extension(PreviewExt(step_callback))
 
-            ext_manager.add_extension(PreviewExt(step_callback))
+        # ext: t2i/ip adapter
+        ext_manager.callbacks.setup(denoise_ctx, ext_manager)
 
-            # get the unet's config so that we can pass the base to sd_step_callback()
-            unet_config = context.models.get_config(self.unet.unet.key)
-
-            # ext: t2i/ip adapter
-            ext_manager.callbacks.setup(denoise_ctx, ext_manager)
-
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-                # ext: controlnet
-                ext_manager.patch_extensions(unet),
-                # ext: freeu, seamless, ip adapter, lora
-                ext_manager.patch_unet(model_state_dict, unet),
-            ):
-                sd_backend = StableDiffusionBackend(unet, scheduler)
-                denoise_ctx.unet = unet
-                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+            # ext: controlnet
+            ext_manager.patch_extensions(unet),
+            # ext: freeu, seamless, ip adapter, lora
+            ext_manager.patch_unet(model_state_dict, unet),
+        ):
+            sd_backend = StableDiffusionBackend(unet, scheduler)
+            denoise_ctx.unet = unet
+            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index d31cb6bdef..b2d6036f63 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -43,11 +43,11 @@ class ModelPatcher:
             processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
         """
         unet_orig_processors = unet.attn_processors
-        try:
-            # create separate instance for each attention, to be able modify each attention separately
-            new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
-            unet.set_attn_processor(new_attn_processors)
+        # create separate instance for each attention, to be able modify each attention separately
+        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+        try:
+            unet.set_attn_processor(unet_new_processors)
 
             yield None
         finally:
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index 08004339e9..213eb5d782 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -8,8 +8,6 @@
 from typing import TYPE_CHECKING, Callable, Dict
 
 import torch
 from diffusers import UNet2DConditionModel
 
-from invokeai.backend.util.devices import TorchDevice
-
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extensions import ExtensionBase