mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from contextlib import contextmanager
|
||
|
from typing import TYPE_CHECKING, Dict, Optional
|
||
|
|
||
|
from diffusers import UNet2DConditionModel
|
||
|
|
||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from invokeai.app.shared.models import FreeUConfig
|
||
|
|
||
|
|
||
|
class FreeUExt(ExtensionBase):
    """Extension that enables FreeU on a UNet for the duration of a patch context.

    The FreeU parameters are read from the ``FreeUConfig`` passed to the
    constructor. When that config is ``None`` the extension does not touch
    the UNet.
    """

    def __init__(
        self,
        freeu_config: Optional[FreeUConfig],
    ) -> None:
        """Store the FreeU settings used later by ``patch_unet``.

        Args:
            freeu_config: Object exposing ``b1``, ``b2``, ``s1`` and ``s2``
                attributes, or ``None`` to make ``patch_unet`` a no-op.
        """
        super().__init__()
        self.freeu_config = freeu_config

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        """Enable FreeU on ``unet`` while the context is active, then disable it.

        Args:
            unet: Model on which ``enable_freeu`` / ``disable_freeu`` are called.
            cached_weights: Accepted for interface compatibility; this
                extension does not read it.

        Yields:
            None. The caller runs its code inside the ``with`` block.

        Raises:
            AssertionError: If FreeU is configured but ``unet`` lacks
                ``enable_freeu`` or ``disable_freeu``. The check happens before
                the UNet is modified.
        """
        config = self.freeu_config
        if config is None:
            # Nothing to apply, so the UNet is not required to support FreeU.
            yield
            return

        # Both checks run up front: they narrow the type for mypy and make sure
        # we never enable FreeU on a model that we could not disable afterwards.
        assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
        assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?

        unet.enable_freeu(
            b1=config.b1,
            b2=config.b2,
            s1=config.s1,
            s2=config.s2,
        )
        try:
            yield
        finally:
            # Always undo the patch, including when the with-body raised.
            unet.disable_freeu()
|