InvokeAI/invokeai/backend/stable_diffusion/extensions/freeu.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

38 lines
991 B
Python
Raw Normal View History

2024-07-21 15:31:10 +00:00
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterator, Optional, Set, Tuple

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
    from invokeai.app.shared.models import FreeUConfig
class FreeUExt(ExtensionBase):
    """Extension that turns on diffusers' FreeU for the UNet while patched.

    The FreeU factors ``b1``, ``b2``, ``s1`` and ``s2`` are read from the
    ``FreeUConfig`` supplied at construction and forwarded to
    ``UNet2DConditionModel.enable_freeu``.
    """

    def __init__(
        self,
        freeu_config: FreeUConfig,
    ):
        """Store the FreeU configuration for later use in ``patch_unet``.

        Args:
            freeu_config: Config whose ``b1``, ``b2``, ``s1`` and ``s2``
                attributes are passed to ``unet.enable_freeu``.
        """
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(
        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
    ) -> Iterator[Tuple[Set[str], Dict[str, torch.Tensor]]]:
        """Enable FreeU on ``unet`` for the duration of the context, then disable it.

        Args:
            unet: The UNet on which FreeU is enabled.
            cached_weights: Accepted for interface compatibility with other
                extensions; not used by this implementation.

        Yields:
            An empty ``set`` and an empty ``dict``. This method does not touch
            any UNet weights directly. NOTE(review): presumably these slots are
            the modified-keys / original-weights pair reported by
            weight-patching extensions — confirm against ``ExtensionBase``.
        """
        try:
            # enable_freeu is inside the try so that disable_freeu runs even if
            # enable_freeu raises partway, instead of leaving the UNet partially
            # patched. Assumes disable_freeu is safe to call when FreeU was not
            # (fully) enabled — TODO confirm for the pinned diffusers version.
            unet.enable_freeu(
                b1=self._freeu_config.b1,
                b2=self._freeu_config.b2,
                s1=self._freeu_config.s1,
                s2=self._freeu_config.s2,
            )
            yield set(), {}
        finally:
            unet.disable_freeu()