mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Minor fixes
This commit is contained in:
parent
137202b77c
commit
79e35bd0d3
@ -723,90 +723,88 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||||
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
# TODO: remove supression when extensions which use models added
|
ext_manager = ExtensionsManager()
|
||||||
with ExitStack() as exit_stack: # noqa: F841
|
|
||||||
ext_manager = ExtensionsManager()
|
|
||||||
|
|
||||||
device = TorchDevice.choose_torch_device()
|
device = TorchDevice.choose_torch_device()
|
||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
noise = noise.to(device=device, dtype=dtype)
|
noise = noise.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
context=context,
|
context=context,
|
||||||
positive_conditioning_field=self.positive_conditioning,
|
positive_conditioning_field=self.positive_conditioning,
|
||||||
negative_conditioning_field=self.negative_conditioning,
|
negative_conditioning_field=self.negative_conditioning,
|
||||||
cfg_scale=self.cfg_scale,
|
cfg_scale=self.cfg_scale,
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
latent_height=latent_height,
|
latent_height=latent_height,
|
||||||
latent_width=latent_width,
|
latent_width=latent_width,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
# TODO: old backend, remove
|
# TODO: old backend, remove
|
||||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
scheduler_info=self.unet.scheduler,
|
scheduler_info=self.unet.scheduler,
|
||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||||
|
scheduler,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
steps=self.steps,
|
||||||
|
denoising_start=self.denoising_start,
|
||||||
|
denoising_end=self.denoising_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
denoise_ctx = DenoiseContext(
|
||||||
|
inputs=DenoiseInputs(
|
||||||
|
orig_latents=latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
noise=noise,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
attention_processor_cls=CustomAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
unet=None,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
scheduler,
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
seed=seed,
|
|
||||||
device=device,
|
|
||||||
steps=self.steps,
|
|
||||||
denoising_start=self.denoising_start,
|
|
||||||
denoising_end=self.denoising_end,
|
|
||||||
)
|
|
||||||
|
|
||||||
denoise_ctx = DenoiseContext(
|
### preview
|
||||||
inputs=DenoiseInputs(
|
def step_callback(state: PipelineIntermediateState) -> None:
|
||||||
orig_latents=latents,
|
context.util.sd_step_callback(state, unet_config.base)
|
||||||
timesteps=timesteps,
|
|
||||||
init_timestep=init_timestep,
|
|
||||||
noise=noise,
|
|
||||||
seed=seed,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
attention_processor_cls=CustomAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
unet=None,
|
|
||||||
scheduler=scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
### preview
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
def step_callback(state: PipelineIntermediateState) -> None:
|
|
||||||
context.util.sd_step_callback(state, unet_config.base)
|
|
||||||
|
|
||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
# ext: t2i/ip adapter
|
||||||
|
ext_manager.callbacks.setup(denoise_ctx, ext_manager)
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
unet_info = context.models.load(self.unet.unet)
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
|
with (
|
||||||
# ext: t2i/ip adapter
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
ext_manager.callbacks.setup(denoise_ctx, ext_manager)
|
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||||
|
# ext: controlnet
|
||||||
unet_info = context.models.load(self.unet.unet)
|
ext_manager.patch_extensions(unet),
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
# ext: freeu, seamless, ip adapter, lora
|
||||||
with (
|
ext_manager.patch_unet(model_state_dict, unet),
|
||||||
unet_info.model_on_device() as (model_state_dict, unet),
|
):
|
||||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||||
# ext: controlnet
|
denoise_ctx.unet = unet
|
||||||
ext_manager.patch_extensions(unet),
|
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
||||||
# ext: freeu, seamless, ip adapter, lora
|
|
||||||
ext_manager.patch_unet(model_state_dict, unet),
|
|
||||||
):
|
|
||||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
|
||||||
denoise_ctx.unet = unet
|
|
||||||
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.detach().to("cpu")
|
result_latents = result_latents.detach().to("cpu")
|
||||||
|
@ -43,11 +43,11 @@ class ModelPatcher:
|
|||||||
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
|
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
|
||||||
"""
|
"""
|
||||||
unet_orig_processors = unet.attn_processors
|
unet_orig_processors = unet.attn_processors
|
||||||
try:
|
|
||||||
# create separate instance for each attention, to be able modify each attention separately
|
|
||||||
new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
|
|
||||||
unet.set_attn_processor(new_attn_processors)
|
|
||||||
|
|
||||||
|
# create separate instance for each attention, to be able modify each attention separately
|
||||||
|
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
|
||||||
|
try:
|
||||||
|
unet.set_attn_processor(unet_new_processors)
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Callable, Dict
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
||||||
|
Loading…
Reference in New Issue
Block a user