feat(nodes): "ModelField" -> "ModelIdentifierField", add hash/name/base/type

This commit is contained in:
psychedelicious 2024-03-09 19:43:24 +11:00
parent 67d26cd633
commit 92b0d13d0e
8 changed files with 44 additions and 36 deletions

View File

@ -35,7 +35,7 @@ from invokeai.app.invocations.fields import (
WithBoard, WithBoard,
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -55,7 +55,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
control_model: ModelField = Field(description="The ControlNet model to use") control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -91,7 +91,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ModelField = InputField( control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
) )
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -20,8 +20,8 @@ from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.") image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -54,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs # Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
input=Input.Direct, input=Input.Direct,
@ -97,7 +97,7 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=self.ip_adapter_model, ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key), image_encoder_model=ModelIdentifierField(key=image_encoder_models[0].key),
weight=self.weight, weight=self.weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,

View File

@ -76,7 +76,7 @@ from .baseinvocation import (
invocation_output, invocation_output,
) )
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelField, UNetField, VAEField from .model import ModelIdentifierField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@ -245,7 +245,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelField, scheduler_info: ModelIdentifierField,
scheduler_name: str, scheduler_name: str,
seed: int, seed: int,
) -> Scheduler: ) -> Scheduler:

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -16,33 +16,39 @@ from .baseinvocation import (
) )
class ModelField(BaseModel): class ModelIdentifierField(BaseModel):
key: str = Field(description="Key of the model") key: str = Field(description="The model's unique key")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None) hash: str = Field(description="The model's BLAKE3 hash")
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
)
class LoRAField(BaseModel): class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model") lora: ModelIdentifierField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model") weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelField = Field(description="Info to load unet submodel") unet: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel") scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class CLIPField(BaseModel): class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel") tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder") skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel): class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel") vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -84,7 +90,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: ModelField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
) )
# TODO: precision? # TODO: precision?
@ -119,7 +125,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
class LoRALoaderInvocation(BaseInvocation): class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -190,7 +196,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderInvocation(BaseInvocation): class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -264,7 +270,7 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
class VAELoaderInvocation(BaseInvocation): class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelField = InputField( vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
) )

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from .model import CLIPField, ModelField, UNetField, VAEField from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
@invocation_output("sdxl_model_loader_output") @invocation_output("sdxl_model_loader_output")
@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: ModelField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@ -72,7 +72,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: ModelField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
) )
# TODO: precision? # TODO: precision?

View File

@ -10,14 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
) )
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterField(BaseModel): class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.") image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.") t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -52,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.") image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: ModelField = InputField( t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct, input=Input.Direct,

View File

@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
""" """
@ -300,7 +300,7 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface):
def exists(self, identifier: Union[str, "ModelField"]) -> bool: def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists. """Checks if a model exists.
Args: Args:
@ -314,7 +314,9 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.store.exists(identifier.key) return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model. """Loads a model.
Args: Args:
@ -361,7 +363,7 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig: def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config. """Gets a model's config.
Args: Args:

View File

@ -36,7 +36,7 @@ from invokeai.app.invocations.model import (
CLIPField, CLIPField,
CLIPOutput, CLIPOutput,
LoRALoaderOutput, LoRALoaderOutput,
ModelField, ModelIdentifierField,
ModelLoaderOutput, ModelLoaderOutput,
SDXLLoRALoaderOutput, SDXLLoRALoaderOutput,
UNetField, UNetField,
@ -114,7 +114,7 @@ __all__ = [
"MetadataItemOutput", "MetadataItemOutput",
"MetadataOutput", "MetadataOutput",
# invokeai.app.invocations.model # invokeai.app.invocations.model
"ModelField", "ModelIdentifierField",
"UNetField", "UNetField",
"CLIPField", "CLIPField",
"VAEField", "VAEField",