From 15b33ad50180eef9a155ff03e36767e9bffcfc07 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 11 Oct 2023 13:49:28 +1100 Subject: [PATCH] feat(nodes): add freeu support Add support for FreeU. See: - https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu - https://github.com/ChenyangSi/FreeU Implementation: - `ModelPatcher.apply_freeu()` handles enabling FreeU (which is very simple with diffusers). - `FreeUConfig` model added to hold the hyperparameters. - `freeu_config` added as optional sub-field on `UNetField`. - `FreeUInvocation` added, works like LoRA - chain it to add the FreeU config to the UNet - No support for model-dependent presets, this will be a future workflow editor enhancement Closes #4845 --- invokeai/app/invocations/baseinvocation.py | 4 ++ invokeai/app/invocations/latent.py | 1 + invokeai/app/invocations/model.py | 62 +++++++++++++++++----- invokeai/app/invocations/shared.py | 16 ++++++ invokeai/backend/model_management/lora.py | 21 ++++++++ 5 files changed, 91 insertions(+), 13 deletions(-) create mode 100644 invokeai/app/invocations/shared.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 497dafa102..71af414f5b 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -95,6 +95,10 @@ class FieldDescriptions: inclusive_low = "The inclusive low value" exclusive_high = "The exclusive high value" decimal_places = "The number of decimal places to round to" + freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' 
+ freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." + freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." class Input(str, Enum): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c6bf37bdbc..077f6135da 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -655,6 +655,7 @@ class DenoiseLatentsInvocation(BaseInvocation): with ( ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), + ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet, ): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 571cb2e730..625d848bce 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -3,6 +3,8 @@ from typing import List, Optional from pydantic import BaseModel, Field +from invokeai.app.invocations.shared import FreeUConfig + from ...backend.model_management import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, @@ -34,6 +36,7 @@ class UNetField(BaseModel): scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') + freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") class ClipField(BaseModel): @@ -49,15 +52,34 @@ class VaeField(BaseModel): seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') -@invocation_output("model_loader_output") -class ModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" +@invocation_output("unet_output") +class 
UNetOutput(BaseInvocationOutput): + """Base class for invocations that output a UNet field""" unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") - clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP") + + +@invocation_output("vae_output") +class VAEOutput(BaseInvocationOutput): + """Base class for invocations that output a VAE field""" + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") +@invocation_output("clip_output") +class CLIPOutput(BaseInvocationOutput): + """Base class for invocations that output a CLIP field""" + + clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP") + + +@invocation_output("model_loader_output") +class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): + """Model loader output""" + + pass + + class MainModelField(BaseModel): """Main model field""" @@ -331,13 +353,6 @@ class VAEModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") -@invocation_output("vae_loader_output") -class VaeLoaderOutput(BaseInvocationOutput): - """VAE output""" - - vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") - - @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -346,7 +361,7 @@ class VaeLoaderInvocation(BaseInvocation): description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE" ) - def invoke(self, context: InvocationContext) -> VaeLoaderOutput: + def invoke(self, context: InvocationContext) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae @@ -357,7 +372,7 @@ class VaeLoaderInvocation(BaseInvocation): model_type=model_type, ): raise Exception(f"Unkown vae name: {model_name}!") - return VaeLoaderOutput( + return VAEOutput( vae=VaeField( 
vae=ModelInfo( model_name=model_name, @@ -407,3 +422,24 @@ class SeamlessModeInvocation(BaseInvocation): vae.seamless_axes = seamless_axes_list return SeamlessModeOutput(unet=unet, vae=vae) + + +@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0") +class FreeUInvocation(BaseInvocation): + """ + Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2): + + SD1.5: 1.2/1.4/0.9/0.2, + SD2: 1.1/1.2/0.9/0.2, + SDXL: 1.1/1.2/0.6/0.4, + """ + + unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet") + b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1) + b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2) + s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) + s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) + + def invoke(self, context: InvocationContext) -> UNetOutput: + self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) + return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/shared.py b/invokeai/app/invocations/shared.py new file mode 100644 index 0000000000..db742a3433 --- /dev/null +++ b/invokeai/app/invocations/shared.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel, Field + +from invokeai.app.invocations.baseinvocation import FieldDescriptions + + +class FreeUConfig(BaseModel): + """ + Configuration for the FreeU hyperparameters. 
+ - https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu + - https://github.com/ChenyangSi/FreeU + """ + + s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1) + s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2) + b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1) + b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index bb44455c88..59aeef19ce 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -12,6 +12,8 @@ from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file from transformers import CLIPTextModel, CLIPTokenizer +from invokeai.app.invocations.shared import FreeUConfig + from .models.lora import LoRAModel """ @@ -231,6 +233,25 @@ class ModelPatcher: while len(skipped_layers) > 0: text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) + @classmethod + @contextmanager + def apply_freeu( + cls, + unet: UNet2DConditionModel, + freeu_config: Optional[FreeUConfig] = None, + ): + did_apply_freeu = False + try: + if freeu_config is not None: + unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) + did_apply_freeu = True + + yield + + finally: + if did_apply_freeu: + unet.disable_freeu() + class TextualInversionModel: embedding: torch.Tensor # [n, 768]|[n, 1280]