Minor fixes

This commit is contained in:
Sergey Borisov 2024-07-17 03:48:37 +03:00
parent 137202b77c
commit 79e35bd0d3
3 changed files with 75 additions and 79 deletions

View File

@@ -723,8 +723,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
# TODO: remove supression when extensions which use models added
with ExitStack() as exit_stack: # noqa: F841
ext_manager = ExtensionsManager()
device = TorchDevice.choose_torch_device()
@@ -782,15 +780,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
ext_manager.add_extension(PreviewExt(step_callback))
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
# ext: t2i/ip adapter
ext_manager.callbacks.setup(denoise_ctx, ext_manager)

View File

@@ -43,11 +43,11 @@ class ModelPatcher:
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors
try:
# create separate instance for each attention, to be able modify each attention separately
new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
unet.set_attn_processor(new_attn_processors)
# create separate instance for each attention, to be able modify each attention separately
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
try:
unet.set_attn_processor(unet_new_processors)
yield None
finally:

View File

@@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Callable, Dict
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions import ExtensionBase