mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add freeu support
Add support for FreeU. See: - https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu - https://github.com/ChenyangSi/FreeU Implementation: - `ModelPatcher.apply_freeu()` handles the enabling freeu (which is very simple with diffusers). - `FreeUConfig` model added to hold the hyperparameters. - `freeu_config` added as optional sub-field on `UNetField`. - `FreeUInvocation` added, works like LoRA - chain it to add the FreeU config to the UNet - No support for model-dependent presets, this will be a future workflow editor enhancement Closes #4845
This commit is contained in:
@ -3,6 +3,8 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.shared import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -34,6 +36,7 @@ class UNetField(BaseModel):
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class ClipField(BaseModel):
|
||||
@ -49,15 +52,34 @@ class VaeField(BaseModel):
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a UNet field"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("vae_output")
|
||||
class VAEOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a VAE field"""
|
||||
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("clip_output")
|
||||
class CLIPOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a CLIP field"""
|
||||
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
@ -331,13 +353,6 @@ class VAEModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@invocation_output("vae_loader_output")
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""VAE output"""
|
||||
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
@ -346,7 +361,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
@ -357,7 +372,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VaeLoaderOutput(
|
||||
return VAEOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
@ -407,3 +422,24 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
vae.seamless_axes = seamless_axes_list
|
||||
|
||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||
|
||||
|
||||
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
|
||||
class FreeUInvocation(BaseInvocation):
|
||||
"""
|
||||
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
|
||||
|
||||
SD1.5: 1.2/1.4/0.9/0.2,
|
||||
SD2: 1.1/1.2/0.9/0.2,
|
||||
SDXL: 1.1/1.2/0.6/0.4,
|
||||
"""
|
||||
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
|
||||
b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
|
||||
b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
|
||||
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
|
||||
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> UNetOutput:
|
||||
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
|
||||
return UNetOutput(unet=self.unet)
|
||||
|
Reference in New Issue
Block a user