Add FreeU support to denoise

This commit is contained in:
Sergey Borisov
2024-07-21 18:31:10 +03:00
parent f9c61f1b6c
commit e046e60e1c
3 changed files with 56 additions and 5 deletions

View File

@ -0,0 +1,42 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
class FreeUExt(ExtensionBase):
def __init__(
self,
freeu_config: Optional[FreeUConfig],
):
super().__init__()
self.freeu_config = freeu_config
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
if self.freeu_config is not None:
unet.enable_freeu(
b1=self.freeu_config.b1,
b2=self.freeu_config.b2,
s1=self.freeu_config.s1,
s2=self.freeu_config.s2,
)
did_apply_freeu = True
yield
finally:
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
if did_apply_freeu:
unet.disable_freeu()

View File

@ -63,9 +63,13 @@ class ExtensionsManager:
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create logic in PR with extension which uses it
yield None
# TODO: create weight patch logic in PR with extension which uses it
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
yield None