from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
    from invokeai.app.shared.models import FreeUConfig
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class FreeUExt(ExtensionBase):
    """Extension that switches on diffusers' FreeU for the UNet while patched.

    The four FreeU factors (``b1``, ``b2``, ``s1``, ``s2``) come from the
    ``FreeUConfig`` handed to the constructor.
    """

    def __init__(
        self,
        freeu_config: FreeUConfig,
    ):
        """Remember the FreeU settings to apply when the UNet is patched.

        Args:
            freeu_config: Object providing the ``b1``/``b2``/``s1``/``s2``
                values forwarded to ``UNet2DConditionModel.enable_freeu``.
        """
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """Enable FreeU on *unet* for the duration of the ``with`` block.

        Args:
            unet: Model on which ``enable_freeu`` is called on entry and
                ``disable_freeu`` is called on exit.
            original_weights: Not used by this extension; presumably accepted
                so the signature matches the hook declared on ``ExtensionBase``
                (base class not visible here — confirm).

        Yields:
            None, once FreeU has been enabled on *unet*.
        """
        cfg = self._freeu_config
        unet.enable_freeu(b1=cfg.b1, b2=cfg.b2, s1=cfg.s1, s2=cfg.s2)
        try:
            yield
        finally:
            # Restore the un-patched UNet even if the body of the `with` raised.
            unet.disable_freeu()