mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add rescale cfg support to denoise
This commit is contained in:
parent
f9c61f1b6c
commit
9a1420280e
@ -59,6 +59,7 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import Custom
|
|||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
@ -790,6 +791,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
|
### cfg rescale
|
||||||
|
if self.cfg_rescale_multiplier > 0:
|
||||||
|
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
|
|
||||||
|
@ -76,12 +76,12 @@ class StableDiffusionBackend:
|
|||||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||||
|
|
||||||
# ext: override apply_cfg
|
# ext: override combine_noise_preds
|
||||||
ctx.noise_pred = self.apply_cfg(ctx)
|
ctx.noise_pred = self.combine_noise_preds(ctx)
|
||||||
|
|
||||||
# ext: cfg_rescale [modify_noise_prediction]
|
# ext: cfg_rescale [modify_noise_prediction]
|
||||||
# TODO: rename
|
# TODO: rename
|
||||||
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||||
@ -95,7 +95,7 @@ class StableDiffusionBackend:
|
|||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
|
||||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[ctx.step_index]
|
guidance_scale = guidance_scale[ctx.step_index]
|
||||||
|
@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
|
|||||||
POST_STEP = "post_step"
|
POST_STEP = "post_step"
|
||||||
PRE_UNET = "pre_unet"
|
PRE_UNET = "pre_unet"
|
||||||
POST_UNET = "post_unet"
|
POST_UNET = "post_unet"
|
||||||
POST_APPLY_CFG = "post_apply_cfg"
|
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
|
||||||
|
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class RescaleCFGExt(ExtensionBase):
    """Denoising extension that rescales the combined noise prediction.

    The extension registers a handler for the ``POST_COMBINE_NOISE_PREDS``
    callback. When invoked, the handler replaces ``ctx.noise_pred`` with a
    version whose standard deviation is pulled toward that of
    ``ctx.positive_noise_pred``.
    """

    def __init__(self, rescale_multiplier: float):
        """Store the blend weight used when rescaling.

        Args:
            rescale_multiplier: Weight given to the rescaled prediction when it
                is blended with the unmodified one. Values that are not
                greater than zero turn the callback into a no-op.
        """
        super().__init__()
        self.rescale_multiplier = rescale_multiplier

    @staticmethod
    def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
        """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf.

        Args:
            total_noise_pred: The combined noise prediction to be rescaled.
            pos_noise_pred: The positive-conditioning noise prediction whose
                standard deviation is used as the target.
            multiplier: Blend weight of the rescaled prediction; the original
                prediction receives ``1.0 - multiplier``.

        Returns:
            The blend of the rescaled and the original combined prediction.
        """
        # Dimensions 1..3 are reduced, so both tensors need at least four
        # dimensions; dimension 0 is kept (presumably the batch axis - confirm
        # against the caller).
        reduce_dims = (1, 2, 3)
        std_cfg = total_noise_pred.std(dim=reduce_dims, keepdim=True)
        std_pos = pos_noise_pred.std(dim=reduce_dims, keepdim=True)

        # Scale the combined prediction so its std matches the positive one.
        rescaled = total_noise_pred * (std_pos / std_cfg)

        # Mix the rescaled and the untouched prediction.
        return multiplier * rescaled + (1.0 - multiplier) * total_noise_pred

    @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
    def rescale_noise_pred(self, ctx: DenoiseContext):
        """Overwrite ``ctx.noise_pred`` with its rescaled counterpart.

        Does nothing unless ``self.rescale_multiplier`` is greater than zero.
        """
        # Written as "not (x > 0)" so that a NaN multiplier is also skipped.
        if not self.rescale_multiplier > 0:
            return

        ctx.noise_pred = self._rescale_cfg(
            ctx.noise_pred,
            ctx.positive_noise_pred,
            self.rescale_multiplier,
        )
|
Loading…
Reference in New Issue
Block a user