mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into chainchompa/simple-upscale-updates
This commit is contained in:
commit
8107884c8d
@ -58,7 +58,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@ -790,18 +792,26 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
### cfg rescale
|
||||
if self.cfg_rescale_multiplier > 0:
|
||||
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
||||
|
||||
### freeu
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(unet),
|
||||
# ext: freeu, seamless, ip adapter, lora
|
||||
ext_manager.patch_unet(model_state_dict, unet),
|
||||
ext_manager.patch_unet(unet, cached_weights),
|
||||
):
|
||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||
denoise_ctx.unet = unet
|
||||
|
@ -83,47 +83,47 @@ class DenoiseContext:
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
# Current state of latent-space image in denoising process.
|
||||
# None until `pre_denoise_loop` callback.
|
||||
# None until `PRE_DENOISE_LOOP` callback.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latents: Optional[torch.Tensor] = None
|
||||
|
||||
# Current denoising step index.
|
||||
# None until `pre_step` callback.
|
||||
# None until `PRE_STEP` callback.
|
||||
step_index: Optional[int] = None
|
||||
|
||||
# Current denoising step timestep.
|
||||
# None until `pre_step` callback.
|
||||
# None until `PRE_STEP` callback.
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
|
||||
# Arguments which will be passed to UNet model.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
|
||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
||||
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
|
||||
step_output: Optional[SchedulerOutput] = None
|
||||
|
||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||
# Available in events inside step(between `pre_step` and `post_stop`).
|
||||
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latent_model_input: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
conditioning_mode: Optional[ConditioningMode] = None
|
||||
|
||||
# [TMP] Noise predictions from negative conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
negative_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Noise predictions from positive conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
positive_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Combined noise prediction from passed conditionings.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
|
@ -76,12 +76,12 @@ class StableDiffusionBackend:
|
||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||
|
||||
# ext: override apply_cfg
|
||||
ctx.noise_pred = self.apply_cfg(ctx)
|
||||
# ext: override combine_noise_preds
|
||||
ctx.noise_pred = self.combine_noise_preds(ctx)
|
||||
|
||||
# ext: cfg_rescale [modify_noise_prediction]
|
||||
# TODO: rename
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||
@ -95,13 +95,15 @@ class StableDiffusionBackend:
|
||||
return step_output
|
||||
|
||||
@staticmethod
|
||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
|
||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[ctx.step_index]
|
||||
|
||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
|
||||
# in slightly different outputs. It is suspected that this is caused by small precision differences.
|
||||
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
|
||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||
sample = ctx.latent_model_input
|
||||
|
@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
|
||||
POST_STEP = "post_step"
|
||||
PRE_UNET = "pre_unet"
|
||||
POST_UNET = "post_unet"
|
||||
POST_APPLY_CFG = "post_apply_cfg"
|
||||
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -56,5 +56,5 @@ class ExtensionBase:
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
yield None
|
||||
|
35
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
35
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
|
||||
class FreeUExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
freeu_config: FreeUConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self._freeu_config = freeu_config
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
unet.enable_freeu(
|
||||
b1=self._freeu_config.b1,
|
||||
b2=self._freeu_config.b2,
|
||||
s1=self._freeu_config.s1,
|
||||
s2=self._freeu_config.s2,
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
unet.disable_freeu()
|
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class RescaleCFGExt(ExtensionBase):
|
||||
def __init__(self, rescale_multiplier: float):
|
||||
super().__init__()
|
||||
self._rescale_multiplier = rescale_multiplier
|
||||
|
||||
@staticmethod
|
||||
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
|
||||
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
||||
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
|
||||
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
||||
return x_final
|
||||
|
||||
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
|
||||
def rescale_noise_pred(self, ctx: DenoiseContext):
|
||||
if self._rescale_multiplier > 0:
|
||||
ctx.noise_pred = self._rescale_cfg(
|
||||
ctx.noise_pred,
|
||||
ctx.positive_noise_pred,
|
||||
self._rescale_multiplier,
|
||||
)
|
@ -63,9 +63,13 @@ class ExtensionsManager:
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: create logic in PR with extension which uses it
|
||||
yield None
|
||||
# TODO: create weight patch logic in PR with extension which uses it
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
||||
|
||||
yield None
|
||||
|
Loading…
Reference in New Issue
Block a user