Modular backend - add FreeU (#6641)

## Summary

Adds the FreeU code from https://github.com/invoke-ai/InvokeAI/pull/6577 to the modular backend.
Also fixes an issue where outputs sometimes differed slightly: the `torch.lerp(...)` form of the CFG calculation is mathematically equivalent to the expanded formula, but appears to introduce small precision differences, so the expanded formula is now used.
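
FreeU (https://arxiv.org/abs/2309.11497) rescales the UNet's backbone and skip-connection features during denoising. `diffusers` exposes this as `enable_freeu`/`disable_freeu` on `UNet2DConditionModel`, which is what the new extension wraps. A minimal standalone sketch of that underlying API (the model ID and the `b`/`s` values are illustrative SD1.5 conventions, not taken from this PR):

```python
from diffusers import UNet2DConditionModel

# Illustrative only: any SD-style UNet works; this repo ID is just an example.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Commonly cited SD1.5 FreeU values; in this PR they come from FreeUConfig.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
# ... run denoising with FreeU active ...
unet.disable_freeu()  # restore the stock forward pass
```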

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run generation with and without the `USE_MODULAR_DENOISE` environment variable set, and compare the outputs.
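
A hedged sketch of how the toggle is assumed to select the code path (the dispatch lives in `DenoiseLatentsInvocation`; the method names below are illustrative, not quoted from the source):

```python
import os

# Assumed dispatch shape; `_new_invoke` / `_old_invoke` are illustrative names.
def invoke(self, context):
    if os.environ.get("USE_MODULAR_DENOISE", False):
        return self._new_invoke(context)  # modular backend path (this PR)
    return self._old_invoke(context)  # legacy denoise path
```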

## Merge Plan

Nope.
If you think there should be some kind of tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
Commit 7b8e25f525, authored by Ryan Dick, 2024-07-23 10:02:56 -04:00, committed by GitHub. 5 changed files with 55 additions and 9 deletions.

#### invokeai/app/invocations/denoise_latents.py

```diff
@@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -795,18 +796,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.cfg_rescale_multiplier > 0:
             ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
 
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
+            ext_manager.patch_unet(unet, cached_weights),
         ):
             sd_backend = StableDiffusionBackend(unet, scheduler)
             denoise_ctx.unet = unet
```

#### invokeai/backend/stable_diffusion/diffusion_backend.py

```diff
@@ -100,8 +100,10 @@ class StableDiffusionBackend:
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]
 
-        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
 
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
```
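
The swapped CFG lines above are mathematically identical, so the divergence comes down to floating-point evaluation order. A small standalone check (illustrative, not part of the PR) that can make the effect visible:

```python
import torch

# lerp(a, b, w) == a + w * (b - a) in exact arithmetic, but fused kernels
# and rounding order can change the last bits of the result.
neg = torch.randn(10_000)
pos = torch.randn(10_000)
scale = 7.5  # CFG uses weights > 1, i.e. extrapolation

lerp_out = torch.lerp(neg, pos, scale)
manual_out = neg + scale * (pos - neg)
print((lerp_out - manual_out).abs().max())  # may be nonzero, dtype/device dependent
```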

#### invokeai/backend/stable_diffusion/extensions/base.py

```diff
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         yield None
```

#### invokeai/backend/stable_diffusion/extensions/freeu.py (new file)

```diff
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: FreeUConfig,
+    ):
+        super().__init__()
+        self._freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        unet.enable_freeu(
+            b1=self._freeu_config.b1,
+            b2=self._freeu_config.b2,
+            s1=self._freeu_config.s1,
+            s2=self._freeu_config.s2,
+        )
+
+        try:
+            yield
+        finally:
+            unet.disable_freeu()
```
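
A hedged usage sketch tying the pieces together (import paths are taken from the diffs above; the config values are illustrative SD1.5 conventions, and constructing `ExtensionsManager` with no arguments is an assumption):

```python
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

# The b1/b2/s1/s2 field names come from the diff above; these particular
# values are commonly cited SD1.5 settings, not values shipped in this PR.
ext_manager = ExtensionsManager()
ext_manager.add_extension(FreeUExt(FreeUConfig(b1=1.5, b2=1.6, s1=0.9, s2=0.2)))

# During denoising, `ext_manager.patch_unet(unet, cached_weights)` enters the
# extension's context: `enable_freeu(...)` on entry, `disable_freeu()` on exit.
```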

#### invokeai/backend/stable_diffusion/extensions_manager.py

```diff
@@ -63,9 +63,13 @@ class ExtensionsManager:
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None
```
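
`ExitStack` enters each extension's `patch_unet` context in registration order and unwinds them in reverse when the block exits. A minimal self-contained illustration of the pattern (toy names, no InvokeAI imports):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def fake_patch(name: str):
    print(f"patch {name}")
    try:
        yield
    finally:
        print(f"unpatch {name}")

# Mirrors ExtensionsManager.patch_unet: enter every context up front, then
# unwind in reverse order when the with-block exits.
with ExitStack() as stack:
    for name in ["freeu", "lora"]:
        stack.enter_context(fake_patch(name))
    print("unet patched; denoising would run here")
# Output: patch freeu / patch lora / unet patched... / unpatch lora / unpatch freeu
```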