From 5f0fe3c8a986cc32ec20a40c3df56145bb0222ab Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 23:09:11 +0300 Subject: [PATCH] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- .../stable_diffusion/diffusion_backend.py | 4 +-- .../stable_diffusion/extensions/base.py | 4 +-- .../stable_diffusion/extensions/freeu.py | 27 +++++++------------ 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 60a21bdc02..5d0a68513f 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -100,8 +100,8 @@ class StableDiffusionBackend: if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] - # Note: Although logically it same, it seams that precision errors differs. - # This sometimes results in slightly different output. + # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result + # in slightly different outputs. It is suspected that this is caused by small precision differences. # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 802af86e6d..6a85a2e441 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel @@ -56,5 +56,5 @@ class ExtensionBase: yield None @contextmanager - def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): yield None diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 0f6c47a773..6ec4fea3fa 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -15,28 +15,21 @@ if TYPE_CHECKING: class FreeUExt(ExtensionBase): def __init__( self, - freeu_config: Optional[FreeUConfig], + freeu_config: FreeUConfig, ): super().__init__() - self.freeu_config = freeu_config + self._freeu_config = freeu_config @contextmanager def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - did_apply_freeu = False + unet.enable_freeu( + b1=self._freeu_config.b1, + b2=self._freeu_config.b2, + s1=self._freeu_config.s1, + s2=self._freeu_config.s2, + ) + try: - assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? - if self.freeu_config is not None: - unet.enable_freeu( - b1=self.freeu_config.b1, - b2=self.freeu_config.b2, - s1=self.freeu_config.s1, - s2=self.freeu_config.s2, - ) - did_apply_freeu = True - yield - finally: - assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? - if did_apply_freeu: - unet.disable_freeu() + unet.disable_freeu()