InvokeAI/invokeai/backend/stable_diffusion/extensions/freeu.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
952 B
Python
Raw Normal View History

2024-07-21 15:31:10 +00:00
from __future__ import annotations
from contextlib import contextmanager
2024-07-30 00:39:01 +00:00
from typing import TYPE_CHECKING
2024-07-21 15:31:10 +00:00
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
2024-07-30 00:39:01 +00:00
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
2024-07-21 15:31:10 +00:00
class FreeUExt(ExtensionBase):
    """Extension that enables FreeU on a UNet while its ``patch_unet`` context is active.

    The FreeU parameters (``b1``, ``b2``, ``s1``, ``s2``) come from the supplied
    ``FreeUConfig``. They are passed to ``unet.enable_freeu`` on entry, and
    ``unet.disable_freeu`` is called on exit.
    """

    def __init__(
        self,
        freeu_config: FreeUConfig,
    ):
        """Store the FreeU configuration to apply later in ``patch_unet``.

        Args:
            freeu_config: Config whose ``b1``, ``b2``, ``s1`` and ``s2`` attributes are
                forwarded to ``unet.enable_freeu``.
        """
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """Enable FreeU on ``unet`` for the duration of the ``with`` block.

        Args:
            unet: The UNet on which FreeU is enabled and later disabled.
            original_weights: Not used by this method. FreeU is toggled through
                ``enable_freeu``/``disable_freeu`` only.
                NOTE(review): presumably kept to match the ``ExtensionBase.patch_unet``
                signature; confirm against the base class.

        Yields:
            None. FreeU is enabled on ``unet`` while the caller's block runs.
        """
        config = self._freeu_config
        try:
            # The enable call lives inside the ``try`` so that if it fails partway through,
            # ``disable_freeu`` still runs and the UNet is not left with a partially
            # applied FreeU configuration.
            # NOTE(review): this assumes ``disable_freeu`` is safe to call on a UNet where
            # FreeU was only partially (or not) enabled, as diffusers' implementation
            # just resets the per-block attributes.
            unet.enable_freeu(
                b1=config.b1,
                b2=config.b2,
                s1=config.s1,
                s2=config.s2,
            )
            yield
        finally:
            unet.disable_freeu()