InvokeAI/invokeai/backend/stable_diffusion/extensions/freeu.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
886 B
Python
Raw Normal View History

2024-07-21 15:31:10 +00:00
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict
2024-07-21 15:31:10 +00:00
2024-07-21 15:37:20 +00:00
import torch
2024-07-21 15:31:10 +00:00
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
class FreeUExt(ExtensionBase):
    """Extension that turns on FreeU for the UNet while its patch is active.

    The FreeU factors ``b1``, ``b2``, ``s1`` and ``s2`` are read from the
    ``FreeUConfig`` given at construction. They are forwarded unchanged to
    ``UNet2DConditionModel.enable_freeu``.
    """

    def __init__(
        self,
        freeu_config: FreeUConfig,
    ):
        """Store the FreeU configuration for later use in :meth:`patch_unet`.

        Args:
            freeu_config: Supplies the ``b1``, ``b2``, ``s1`` and ``s2`` values
                passed to ``unet.enable_freeu``.
        """
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
        """Enable FreeU on ``unet`` for the duration of the ``with`` block.

        Args:
            unet: Model patched in place via ``enable_freeu``/``disable_freeu``.
            original_weights: Not read by this extension. Presumably accepted
                to match the ``patch_unet`` signature expected by
                ``ExtensionBase`` -- confirm against the base class.

        Yields:
            None. FreeU is enabled inside the ``with`` body and disabled on exit,
            including when the body raises.
        """
        cfg = self._freeu_config
        try:
            # Enable inside the ``try``: if enable_freeu fails after setting only
            # some of the factors, the ``finally`` still runs disable_freeu, so
            # no partially applied patch is left on the shared UNet.
            unet.enable_freeu(
                b1=cfg.b1,
                b2=cfg.b2,
                s1=cfg.s1,
                s2=cfg.s2,
            )
            yield
        finally:
            # NOTE(review): this unconditionally disables FreeU. It does not
            # restore any FreeU settings the UNet had before the patch. Confirm
            # callers never enable FreeU on the UNet independently.
            unet.disable_freeu()