2024-07-21 15:31:10 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from contextlib import contextmanager
|
2024-07-27 01:25:15 +00:00
|
|
|
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
2024-07-21 15:31:10 +00:00
|
|
|
|
2024-07-21 15:37:20 +00:00
|
|
|
import torch
|
2024-07-21 15:31:10 +00:00
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
|
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.app.shared.models import FreeUConfig
|
|
|
|
|
|
|
|
|
|
|
|
class FreeUExt(ExtensionBase):
    """Extension that keeps FreeU enabled on a UNet while it is patched.

    The FreeU parameters (``b1``, ``b2``, ``s1``, ``s2``) are read from the
    config object supplied at construction time.
    """

    def __init__(self, freeu_config: FreeUConfig):
        """Store the FreeU settings for later use by ``patch_unet``.

        Args:
            freeu_config: Object exposing the ``b1``, ``b2``, ``s1`` and ``s2``
                attributes that are forwarded to ``unet.enable_freeu``.
        """
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(
        self,
        unet: UNet2DConditionModel,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
        """Enable FreeU on ``unet`` for the duration of the ``with`` block.

        Args:
            unet: Model on which ``enable_freeu`` / ``disable_freeu`` are called.
            cached_weights: Accepted as part of the signature; this extension
                does not read it.

        Yields:
            A pair of an empty set and an empty dict.

        NOTE(review): the return annotation describes the yielded pair, not the
        context manager produced by ``@contextmanager`` - confirm this matches
        the convention used by the base class.
        """
        config = self._freeu_config
        unet.enable_freeu(b1=config.b1, b2=config.b2, s1=config.s1, s2=config.s2)

        try:
            # Both yielded containers are always empty for this extension.
            no_keys: Set[str] = set()
            no_weights: Dict[str, torch.Tensor] = {}
            yield no_keys, no_weights
        finally:
            # Runs even if the body of the ``with`` block raises.
            unet.disable_freeu()
|