fix(mm): add ui_type to model fields

Recently the schema for models was changed to a generic `ModelField`, and the UI was unable to derive the type of those fields. This didn't affect functionality, but it did break the styling of handles.

Add `ui_type` to the affected fields and update the UI to use the correct capitalizations.
This commit is contained in:
psychedelicious 2024-03-08 21:37:00 +11:00 committed by Mary Hipp Rogers
parent fe2c6f621a
commit ddde355b09
6 changed files with 25 additions and 16 deletions

View File

@ -31,6 +31,7 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
@ -90,7 +91,9 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_model: ModelField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)

View File

@ -39,13 +39,15 @@ class UIType(str, Enum, metaclass=MetaEnum):
"""
# region Model Field Types
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VAEModelField"
VAEModel = "VAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
# endregion
# region Misc Field Types
@ -86,7 +88,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
StringPolymorphic = "DEPRECATED_StringPolymorphic"
MainModel = "DEPRECATED_MainModel"
UNet = "DEPRECATED_UNet"
Vae = "DEPRECATED_Vae"
CLIP = "DEPRECATED_CLIP"

View File

@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -55,7 +55,11 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
weight: Union[float, List[float]] = InputField(

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
@ -84,7 +84,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@ -117,7 +117,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -186,7 +186,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -262,6 +262,7 @@ class VAELoaderInvocation(BaseInvocation):
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
ui_type=UIType.VAEModel
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -57,6 +57,7 @@ class T2IAdapterInvocation(BaseInvocation):
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"

View File

@ -35,10 +35,9 @@ export const MODEL_TYPES = [
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'VaeField',
'ClipField',
'VAEField',
'CLIPField',
'T2IAdapterModelField',
'IPAdapterModelField',
];
/**
@ -47,7 +46,7 @@ export const MODEL_TYPES = [
export const FIELD_COLORS: { [key: string]: string } = {
BoardField: 'purple.500',
BooleanField: 'green.500',
ClipField: 'green.500',
CLIPField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
ControlField: 'teal.500',
@ -67,6 +66,6 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
VaeField: 'blue.500',
VaeModelField: 'teal.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};