mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add FreeU support to denoise
This commit is contained in:
parent
f9c61f1b6c
commit
e046e60e1c
@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
@ -790,18 +791,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
|
### freeu
|
||||||
|
if self.unet.freeu_config:
|
||||||
|
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
|
|
||||||
unet_info = context.models.load(self.unet.unet)
|
unet_info = context.models.load(self.unet.unet)
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
unet_info.model_on_device() as (model_state_dict, unet),
|
unet_info.model_on_device() as (cached_weights, unet),
|
||||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||||
# ext: controlnet
|
# ext: controlnet
|
||||||
ext_manager.patch_extensions(unet),
|
ext_manager.patch_extensions(unet),
|
||||||
# ext: freeu, seamless, ip adapter, lora
|
# ext: freeu, seamless, ip adapter, lora
|
||||||
ext_manager.patch_unet(model_state_dict, unet),
|
ext_manager.patch_unet(unet, cached_weights),
|
||||||
):
|
):
|
||||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||||
denoise_ctx.unet = unet
|
denoise_ctx.unet = unet
|
||||||
|
42
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
42
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, Dict, Optional
|
||||||
|
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
|
|
||||||
|
|
||||||
|
class FreeUExt(ExtensionBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
freeu_config: Optional[FreeUConfig],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.freeu_config = freeu_config
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
|
did_apply_freeu = False
|
||||||
|
try:
|
||||||
|
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
|
||||||
|
if self.freeu_config is not None:
|
||||||
|
unet.enable_freeu(
|
||||||
|
b1=self.freeu_config.b1,
|
||||||
|
b2=self.freeu_config.b2,
|
||||||
|
s1=self.freeu_config.s1,
|
||||||
|
s2=self.freeu_config.s2,
|
||||||
|
)
|
||||||
|
did_apply_freeu = True
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
|
||||||
|
if did_apply_freeu:
|
||||||
|
unet.disable_freeu()
|
@ -63,9 +63,13 @@ class ExtensionsManager:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
if self._is_canceled and self._is_canceled():
|
if self._is_canceled and self._is_canceled():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
# TODO: create logic in PR with extension which uses it
|
# TODO: create weight patch logic in PR with extension which uses it
|
||||||
yield None
|
with ExitStack() as exit_stack:
|
||||||
|
for ext in self._extensions:
|
||||||
|
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
||||||
|
|
||||||
|
yield None
|
||||||
|
Loading…
Reference in New Issue
Block a user