feat(nodes): make all ModelIdentifierField inputs accept connections

This commit is contained in:
psychedelicious 2024-05-17 20:15:04 +10:00
parent 85a5a7c47a
commit fe7ed72c9c
5 changed files with 21 additions and 26 deletions

View File

@ -24,7 +24,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
Input,
InputField, InputField,
OutputField, OutputField,
UIType, UIType,
@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control) control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField( control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
) )
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.IPAdapterModel, ui_type=UIType.IPAdapterModel,
) )

View File

@ -98,14 +98,12 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
title="Main Model", title="Main Model",
tags=["model"], tags=["model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision? # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@ -134,12 +132,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") @invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
class LoRALoaderInvocation(BaseInvocation): class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -197,12 +195,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA") lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0") @invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation): class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight.""" """Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -273,13 +271,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA", title="SDXL LoRA",
tags=["lora", "model"], tags=["lora", "model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLLoRALoaderInvocation(BaseInvocation): class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -414,12 +412,12 @@ class SDXLLoRACollectionLoader(BaseInvocation):
return output return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation): class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField( vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
) )
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import SubModelType
@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") @invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model", title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"], tags=["model", "sdxl", "refiner"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
) )
# TODO: precision? # TODO: precision?

View File

@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation( @invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
) )
class T2IAdapterInvocation(BaseInvocation): class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes.""" """Collects T2I-Adapter info to pass to other nodes."""
@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField( t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.T2IAdapterModel, ui_type=UIType.T2IAdapterModel,
) )