Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-22 23:09:11 +03:00
parent 1748848b7b
commit 5f0fe3c8a9
3 changed files with 14 additions and 21 deletions

View File

@ -100,8 +100,8 @@ class StableDiffusionBackend:
if isinstance(guidance_scale, list): if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index] guidance_scale = guidance_scale[ctx.step_index]
# Note: Although logically it same, it seams that precision errors differs. # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
# This sometimes results in slightly different output. # in slightly different outputs. It is suspected that this is caused by small precision differences.
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -56,5 +56,5 @@ class ExtensionBase:
yield None yield None
@contextmanager @contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
yield None yield None

View File

@ -15,28 +15,21 @@ if TYPE_CHECKING:
class FreeUExt(ExtensionBase): class FreeUExt(ExtensionBase):
def __init__( def __init__(
self, self,
freeu_config: Optional[FreeUConfig], freeu_config: FreeUConfig,
): ):
super().__init__() super().__init__()
self.freeu_config = freeu_config self._freeu_config = freeu_config
@contextmanager @contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
did_apply_freeu = False unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,
s1=self._freeu_config.s1,
s2=self._freeu_config.s2,
)
try: try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
if self.freeu_config is not None:
unet.enable_freeu(
b1=self.freeu_config.b1,
b2=self.freeu_config.b2,
s1=self.freeu_config.s1,
s2=self.freeu_config.s2,
)
did_apply_freeu = True
yield yield
finally: finally:
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? unet.disable_freeu()
if did_apply_freeu:
unet.disable_freeu()