Minor fixes

Sergey Borisov 2024-07-17 03:48:37 +03:00
parent 137202b77c
commit 79e35bd0d3
3 changed files with 75 additions and 79 deletions

@@ -723,90 +723,88 @@ class DenoiseLatentsInvocation(BaseInvocation):
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
-        # TODO: remove supression when extensions which use models added
-        with ExitStack() as exit_stack:  # noqa: F841
-            ext_manager = ExtensionsManager()
+        ext_manager = ExtensionsManager()
 
-            device = TorchDevice.choose_torch_device()
-            dtype = TorchDevice.choose_torch_dtype()
+        device = TorchDevice.choose_torch_device()
+        dtype = TorchDevice.choose_torch_dtype()
 
-            seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
-            latents = latents.to(device=device, dtype=dtype)
-            if noise is not None:
-                noise = noise.to(device=device, dtype=dtype)
+        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
 
-            _, _, latent_height, latent_width = latents.shape
+        _, _, latent_height, latent_width = latents.shape
 
-            conditioning_data = self.get_conditioning_data(
-                context=context,
-                positive_conditioning_field=self.positive_conditioning,
-                negative_conditioning_field=self.negative_conditioning,
-                cfg_scale=self.cfg_scale,
-                steps=self.steps,
-                latent_height=latent_height,
-                latent_width=latent_width,
-                device=device,
-                dtype=dtype,
-                # TODO: old backend, remove
-                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
-            )
+        conditioning_data = self.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            device=device,
+            dtype=dtype,
+            # TODO: old backend, remove
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
 
-            scheduler = get_scheduler(
-                context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
-                seed=seed,
-            )
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+            seed=seed,
+        )
 
-            timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                scheduler,
-                seed=seed,
-                device=device,
-                steps=self.steps,
-                denoising_start=self.denoising_start,
-                denoising_end=self.denoising_end,
-            )
+        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+            scheduler,
+            seed=seed,
+            device=device,
+            steps=self.steps,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+        )
 
-            denoise_ctx = DenoiseContext(
-                inputs=DenoiseInputs(
-                    orig_latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    attention_processor_cls=CustomAttnProcessor2_0,
-                ),
-                unet=None,
-                scheduler=scheduler,
-            )
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
 
-            ### preview
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
 
-            ext_manager.add_extension(PreviewExt(step_callback))
+        ### preview
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
 
-            # get the unet's config so that we can pass the base to sd_step_callback()
-            unet_config = context.models.get_config(self.unet.unet.key)
+        ext_manager.add_extension(PreviewExt(step_callback))
 
-            # ext: t2i/ip adapter
-            ext_manager.callbacks.setup(denoise_ctx, ext_manager)
+        # ext: t2i/ip adapter
+        ext_manager.callbacks.setup(denoise_ctx, ext_manager)
 
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-                # ext: controlnet
-                ext_manager.patch_extensions(unet),
-                # ext: freeu, seamless, ip adapter, lora
-                ext_manager.patch_unet(model_state_dict, unet),
-            ):
-                sd_backend = StableDiffusionBackend(unet, scheduler)
-                denoise_ctx.unet = unet
-                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+            # ext: controlnet
+            ext_manager.patch_extensions(unet),
+            # ext: freeu, seamless, ip adapter, lora
+            ext_manager.patch_unet(model_state_dict, unet),
+        ):
+            sd_backend = StableDiffusionBackend(unet, scheduler)
+            denoise_ctx.unet = unet
+            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
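
Two things change in this hunk. The ExitStack wrapper (and the TODO explaining it) was a placeholder that managed nothing yet, so removing it dedents the whole body one level. Separately, the unet_config lookup moves above the step_callback that closes over it. The old ordering still ran, because Python closures resolve free variables when the function is called, not when it is defined, and the callback only fires once denoising is underway; the new ordering simply reads top-down. A minimal standalone sketch of that late-binding behavior (make_callback and config are illustrative, not InvokeAI code):

    def make_callback():
        def step_callback() -> str:
            # `config` is resolved when step_callback is *called*, so this
            # works even though `config` is assigned below the def.
            return config["base"]

        config = {"base": "sd-1"}  # hypothetical stand-in for unet_config
        return step_callback

    print(make_callback()())  # -> sd-1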

@@ -43,11 +43,11 @@ class ModelPatcher:
             processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
         """
         unet_orig_processors = unet.attn_processors
-        try:
-            # create separate instance for each attention, to be able modify each attention separately
-            new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
-            unet.set_attn_processor(new_attn_processors)
+
+        # create separate instance for each attention, to be able modify each attention separately
+        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+
+        try:
+            unet.set_attn_processor(unet_new_processors)
             yield None
         finally:
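
This reshuffle narrows the try block: building the replacement processor dict needs no cleanup, so it moves out, and only set_attn_processor, the call that the finally clause undoes, stays inside. In the old shape, an exception while constructing the dict would still route through finally and re-apply the original processors even though nothing had been patched. A minimal sketch of the pattern, with a hypothetical swap_attr standing in for the attention-processor swap:

    from contextlib import contextmanager
    from typing import Any, Iterator

    @contextmanager
    def swap_attr(obj: Any, attr: str, new_value: Any) -> Iterator[None]:
        original = getattr(obj, attr)  # snapshot, like unet_orig_processors

        # Fallible preparation belongs outside the try block: if it raises,
        # there is nothing to restore yet.
        prepared = new_value

        try:
            setattr(obj, attr, prepared)  # the only step finally must undo
            yield
        finally:
            setattr(obj, attr, original)  # always restore the snapshot

Used as "with swap_attr(cfg, 'mode', 'patched'): ...", the attribute reverts on normal exit and on exceptions alike, and the finally clause only ever runs after the swap it is meant to undo was actually applied.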

@@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Callable, Dict
 import torch
 from diffusers import UNet2DConditionModel
 
-from invokeai.backend.util.devices import TorchDevice
-
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extensions import ExtensionBase
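
The last hunk only drops an import that nothing in the module references anymore. Dead imports like this are what pyflakes-style linters (flake8, ruff) report as F401, the same rule family as the noqa: F841 suppression that the first hunk deletes along with its ExitStack.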