Merge branch 'main' into stalker-modular_controlnet

This commit is contained in:
Ryan Dick
2024-07-23 10:19:12 -04:00
53 changed files with 1504 additions and 283 deletions

View File

@ -59,7 +59,9 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import Custom
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@ -823,6 +825,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
ext_manager.add_extension(PreviewExt(step_callback))
### cfg rescale
if self.cfg_rescale_multiplier > 0:
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
### freeu
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
@ -837,12 +847,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
ext_manager.patch_unet(unet, cached_weights),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet