Revert "feat(nodes): add freeu support"

This commit is contained in:
Millun Atluri 2023-11-07 14:05:34 +11:00 committed by GitHub
parent 9733cd4199
commit 2e5cca5401
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 13 additions and 92 deletions

View File

@ -92,10 +92,6 @@ class FieldDescriptions:
inclusive_low = "The inclusive low value" inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value" exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to" decimal_places = "The number of decimal places to round to"
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
class Input(str, Enum): class Input(str, Enum):

View File

@ -710,8 +710,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes), set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet, unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.

View File

@ -3,8 +3,6 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.shared import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -38,7 +36,6 @@ class UNetField(BaseModel):
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel): class ClipField(BaseModel):
@ -54,32 +51,13 @@ class VaeField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@invocation_output("vae_output")
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output") @invocation_output("model_loader_output")
class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
pass unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class MainModelField(BaseModel): class MainModelField(BaseModel):
@ -388,6 +366,13 @@ class VAEModelField(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@invocation_output("vae_loader_output")
class VaeLoaderOutput(BaseInvocationOutput):
"""VAE output"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation): class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
@ -399,7 +384,7 @@ class VaeLoaderInvocation(BaseInvocation):
title="VAE", title="VAE",
) )
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model base_model = self.vae_model.base_model
model_name = self.vae_model.model_name model_name = self.vae_model.model_name
model_type = ModelType.Vae model_type = ModelType.Vae
@ -410,7 +395,7 @@ class VaeLoaderInvocation(BaseInvocation):
model_type=model_type, model_type=model_type,
): ):
raise Exception(f"Unkown vae name: {model_name}!") raise Exception(f"Unkown vae name: {model_name}!")
return VAEOutput( return VaeLoaderOutput(
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=model_name, model_name=model_name,
@ -472,24 +457,3 @@ class SeamlessModeInvocation(BaseInvocation):
vae.seamless_axes = seamless_axes_list vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae) return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
SD1.5: 1.2/1.4/0.9/0.2,
SD2: 1.1/1.2/0.9/0.2,
SDXL: 1.1/1.2/0.6/0.4,
"""
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)

View File

@ -1,16 +0,0 @@
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import FieldDescriptions
class FreeUConfig(BaseModel):
"""
Configuration for the FreeU hyperparameters.
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
- https://github.com/ChenyangSi/FreeU
"""
s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2)
b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2)

View File

@ -12,8 +12,6 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from invokeai.app.invocations.shared import FreeUConfig
from .models.lora import LoRAModel from .models.lora import LoRAModel
""" """
@ -242,25 +240,6 @@ class ModelPatcher:
while len(skipped_layers) > 0: while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
@classmethod
@contextmanager
def apply_freeu(
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
):
did_apply_freeu = False
try:
if freeu_config is not None:
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
did_apply_freeu = True
yield
finally:
if did_apply_freeu:
unet.disable_freeu()
class TextualInversionModel: class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]