Merge branch 'main' into chainchompa/simple-upscale-updates

This commit is contained in:
chainchompa 2024-07-23 10:28:11 -04:00 committed by GitHub
commit 8107884c8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 111 additions and 24 deletions

View File

@ -58,7 +58,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@ -790,18 +792,26 @@ class DenoiseLatentsInvocation(BaseInvocation):
ext_manager.add_extension(PreviewExt(step_callback))
### cfg rescale
if self.cfg_rescale_multiplier > 0:
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
### freeu
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
ext_manager.patch_unet(unet, cached_weights),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet

View File

@ -83,47 +83,47 @@ class DenoiseContext:
unet: Optional[UNet2DConditionModel] = None
# Current state of latent-space image in denoising process.
# None until `pre_denoise_loop` callback.
# None until `PRE_DENOISE_LOOP` callback.
# Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None
# Current denoising step index.
# None until `pre_step` callback.
# None until `PRE_STEP` callback.
step_index: Optional[int] = None
# Current denoising step timestep.
# None until `pre_step` callback.
# None until `PRE_STEP` callback.
timestep: Optional[torch.Tensor] = None
# Arguments which will be passed to UNet model.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
unet_kwargs: Optional[UNetKwargs] = None
# SchedulerOutput class returned from step function(normally, generated by scheduler).
# Supposed to be used only in `post_step` callback, otherwise can be None.
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
step_output: Optional[SchedulerOutput] = None
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
# Available in events inside step(between `pre_step` and `post_stop`).
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
# Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None
# [TMP] Defines on which conditionings current unet call will be runned.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
conditioning_mode: Optional[ConditioningMode] = None
# [TMP] Noise predictions from negative conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None
# [TMP] Noise predictions from positive conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None
# Combined noise prediction from passed conditionings.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None

View File

@ -76,12 +76,12 @@ class StableDiffusionBackend:
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
# ext: override apply_cfg
ctx.noise_pred = self.apply_cfg(ctx)
# ext: override combine_noise_preds
ctx.noise_pred = self.combine_noise_preds(ctx)
# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
@ -95,13 +95,15 @@ class StableDiffusionBackend:
return step_output
@staticmethod
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
# in slightly different outputs. It is suspected that this is caused by small precision differences.
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
sample = ctx.latent_model_input

View File

@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
POST_STEP = "post_step"
PRE_UNET = "pre_unet"
POST_UNET = "post_unet"
POST_APPLY_CFG = "post_apply_cfg"
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
@ -56,5 +56,5 @@ class ExtensionBase:
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
yield None

View File

@ -0,0 +1,35 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
class FreeUExt(ExtensionBase):
def __init__(
self,
freeu_config: FreeUConfig,
):
super().__init__()
self._freeu_config = freeu_config
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,
s1=self._freeu_config.s1,
s2=self._freeu_config.s2,
)
try:
yield
finally:
unet.disable_freeu()

View File

@ -0,0 +1,36 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
class RescaleCFGExt(ExtensionBase):
def __init__(self, rescale_multiplier: float):
super().__init__()
self._rescale_multiplier = rescale_multiplier
@staticmethod
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
def rescale_noise_pred(self, ctx: DenoiseContext):
if self._rescale_multiplier > 0:
ctx.noise_pred = self._rescale_cfg(
ctx.noise_pred,
ctx.positive_noise_pred,
self._rescale_multiplier,
)

View File

@ -63,9 +63,13 @@ class ExtensionsManager:
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create logic in PR with extension which uses it
yield None
# TODO: create weight patch logic in PR with extension which uses it
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
yield None