mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
1748848b7b
commit
5f0fe3c8a9
@ -100,8 +100,8 @@ class StableDiffusionBackend:
|
|||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[ctx.step_index]
|
guidance_scale = guidance_scale[ctx.step_index]
|
||||||
|
|
||||||
# Note: Although logically it same, it seams that precision errors differs.
|
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
|
||||||
# This sometimes results in slightly different output.
|
# in slightly different outputs. It is suspected that this is caused by small precision differences.
|
||||||
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||||
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -56,5 +56,5 @@ class ExtensionBase:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
yield None
|
yield None
|
||||||
|
@ -15,28 +15,21 @@ if TYPE_CHECKING:
|
|||||||
class FreeUExt(ExtensionBase):
|
class FreeUExt(ExtensionBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
freeu_config: Optional[FreeUConfig],
|
freeu_config: FreeUConfig,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.freeu_config = freeu_config
|
self._freeu_config = freeu_config
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
did_apply_freeu = False
|
|
||||||
try:
|
|
||||||
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
|
|
||||||
if self.freeu_config is not None:
|
|
||||||
unet.enable_freeu(
|
unet.enable_freeu(
|
||||||
b1=self.freeu_config.b1,
|
b1=self._freeu_config.b1,
|
||||||
b2=self.freeu_config.b2,
|
b2=self._freeu_config.b2,
|
||||||
s1=self.freeu_config.s1,
|
s1=self._freeu_config.s1,
|
||||||
s2=self.freeu_config.s2,
|
s2=self._freeu_config.s2,
|
||||||
)
|
)
|
||||||
did_apply_freeu = True
|
|
||||||
|
|
||||||
|
try:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
|
|
||||||
if did_apply_freeu:
|
|
||||||
unet.disable_freeu()
|
unet.disable_freeu()
|
||||||
|
Loading…
Reference in New Issue
Block a user