mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Revert "feat(nodes): add freeu support"
This commit is contained in:
parent
9733cd4199
commit
2e5cca5401
@ -92,10 +92,6 @@ class FieldDescriptions:
|
|||||||
inclusive_low = "The inclusive low value"
|
inclusive_low = "The inclusive low value"
|
||||||
exclusive_high = "The exclusive high value"
|
exclusive_high = "The exclusive high value"
|
||||||
decimal_places = "The number of decimal places to round to"
|
decimal_places = "The number of decimal places to round to"
|
||||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
|
||||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
|
||||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
|
||||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
|
||||||
|
|
||||||
|
|
||||||
class Input(str, Enum):
|
class Input(str, Enum):
|
||||||
|
@ -710,8 +710,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
|
|
||||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
|
||||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||||
unet_info as unet,
|
unet_info as unet,
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
|
@ -3,8 +3,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.shared import FreeUConfig
|
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -38,7 +36,6 @@ class UNetField(BaseModel):
|
|||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
@ -54,32 +51,13 @@ class VaeField(BaseModel):
|
|||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("unet_output")
|
|
||||||
class UNetOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a UNet field"""
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("vae_output")
|
|
||||||
class VAEOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a VAE field"""
|
|
||||||
|
|
||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_output")
|
|
||||||
class CLIPOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a CLIP field"""
|
|
||||||
|
|
||||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("model_loader_output")
|
@invocation_output("model_loader_output")
|
||||||
class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
pass
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
@ -388,6 +366,13 @@ class VAEModelField(BaseModel):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("vae_loader_output")
|
||||||
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""VAE output"""
|
||||||
|
|
||||||
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
@ -399,7 +384,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
title="VAE",
|
title="VAE",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
base_model = self.vae_model.base_model
|
base_model = self.vae_model.base_model
|
||||||
model_name = self.vae_model.model_name
|
model_name = self.vae_model.model_name
|
||||||
model_type = ModelType.Vae
|
model_type = ModelType.Vae
|
||||||
@ -410,7 +395,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown vae name: {model_name}!")
|
raise Exception(f"Unkown vae name: {model_name}!")
|
||||||
return VAEOutput(
|
return VaeLoaderOutput(
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -472,24 +457,3 @@ class SeamlessModeInvocation(BaseInvocation):
|
|||||||
vae.seamless_axes = seamless_axes_list
|
vae.seamless_axes = seamless_axes_list
|
||||||
|
|
||||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||||
|
|
||||||
|
|
||||||
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
|
|
||||||
class FreeUInvocation(BaseInvocation):
|
|
||||||
"""
|
|
||||||
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
|
|
||||||
|
|
||||||
SD1.5: 1.2/1.4/0.9/0.2,
|
|
||||||
SD2: 1.1/1.2/0.9/0.2,
|
|
||||||
SDXL: 1.1/1.2/0.6/0.4,
|
|
||||||
"""
|
|
||||||
|
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
|
|
||||||
b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
|
|
||||||
b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
|
|
||||||
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
|
|
||||||
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> UNetOutput:
|
|
||||||
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
|
|
||||||
return UNetOutput(unet=self.unet)
|
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import FieldDescriptions
|
|
||||||
|
|
||||||
|
|
||||||
class FreeUConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
Configuration for the FreeU hyperparameters.
|
|
||||||
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
|
|
||||||
- https://github.com/ChenyangSi/FreeU
|
|
||||||
"""
|
|
||||||
|
|
||||||
s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1)
|
|
||||||
s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2)
|
|
||||||
b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1)
|
|
||||||
b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2)
|
|
@ -12,8 +12,6 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.invocations.shared import FreeUConfig
|
|
||||||
|
|
||||||
from .models.lora import LoRAModel
|
from .models.lora import LoRAModel
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -242,25 +240,6 @@ class ModelPatcher:
|
|||||||
while len(skipped_layers) > 0:
|
while len(skipped_layers) > 0:
|
||||||
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_freeu(
|
|
||||||
cls,
|
|
||||||
unet: UNet2DConditionModel,
|
|
||||||
freeu_config: Optional[FreeUConfig] = None,
|
|
||||||
):
|
|
||||||
did_apply_freeu = False
|
|
||||||
try:
|
|
||||||
if freeu_config is not None:
|
|
||||||
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
|
|
||||||
did_apply_freeu = True
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if did_apply_freeu:
|
|
||||||
unet.disable_freeu()
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
Loading…
Reference in New Issue
Block a user