feat(nodes): "ModelField" -> "ModelIdentifierField", add hash/name/base/type

psychedelicious 2024-03-09 19:43:24 +11:00
parent 67d26cd633
commit 92b0d13d0e
8 changed files with 44 additions and 36 deletions
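In short: the bare-key `ModelField` becomes a self-describing `ModelIdentifierField` that also carries the model's BLAKE3 hash, name, base model, and type. A sketch of what constructing the new identifier looks like after this commit (the key, hash, and name values are made up for illustration; the enum members are from `invokeai.backend.model_manager.config`):

```python
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, ModelType

# All five identifying fields are now required; pydantic raises a
# ValidationError if any is missing. Values below are hypothetical.
ident = ModelIdentifierField(
    key="8d2a76bc-example-key",
    hash="blake3:9f86d081884c7d65-example",
    name="stable-diffusion-v1-5",
    base=BaseModelType.StableDiffusion1,
    type=ModelType.Main,
)
```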

invokeai/app/invocations/controlnet_image_processors.py

@@ -35,7 +35,7 @@ from invokeai.app.invocations.fields import (
     WithBoard,
     WithMetadata,
 )
-from invokeai.app.invocations.model import ModelField
+from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -55,7 +55,7 @@ CONTROLNET_RESIZE_VALUES = Literal[

 class ControlField(BaseModel):
     image: ImageField = Field(description="The control image")
-    control_model: ModelField = Field(description="The ControlNet model to use")
+    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
     control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -91,7 +91,7 @@ class ControlNetInvocation(BaseInvocation):
     """Collects ControlNet info to pass to other nodes"""

     image: ImageField = InputField(description="The control image")
-    control_model: ModelField = InputField(
+    control_model: ModelIdentifierField = InputField(
         description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
     )
     control_weight: Union[float, List[float]] = InputField(

invokeai/app/invocations/ip_adapter.py

@@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
-from invokeai.app.invocations.model import ModelField
+from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -20,8 +20,8 @@ from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig

 class IPAdapterField(BaseModel):
     image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
-    ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
-    image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
+    ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
+    image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@@ -54,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):

     # Inputs
     image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
-    ip_adapter_model: ModelField = InputField(
+    ip_adapter_model: ModelIdentifierField = InputField(
         description="The IP-Adapter model.",
         title="IP-Adapter Model",
         input=Input.Direct,
@@ -97,7 +97,7 @@ class IPAdapterInvocation(BaseInvocation):
             ip_adapter=IPAdapterField(
                 image=self.image,
                 ip_adapter_model=self.ip_adapter_model,
-                image_encoder_model=ModelField(key=image_encoder_models[0].key),
+                image_encoder_model=ModelIdentifierField(key=image_encoder_models[0].key),
                 weight=self.weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
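As of this commit the image-encoder identifier above is still built from a bare `key`. With `hash`, `name`, `base`, and `type` now required on `ModelIdentifierField`, a caller that has the full config record could populate the identifier completely. A minimal sketch under that assumption, not the commit's code (it assumes the config object exposes `key`/`hash`/`name`/`base`/`type` attributes, as InvokeAI model configs do):

```python
# Sketch: populate every required field of the new identifier from a
# model config record rather than from the key alone.
config = image_encoder_models[0]
image_encoder_model = ModelIdentifierField(
    key=config.key,
    hash=config.hash,
    name=config.name,
    base=config.base,
    type=config.type,
)
```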

invokeai/app/invocations/latent.py

@@ -76,7 +76,7 @@ from .baseinvocation import (
     invocation_output,
 )
 from .controlnet_image_processors import ControlField
-from .model import ModelField, UNetField, VAEField
+from .model import ModelIdentifierField, UNetField, VAEField

 if choose_torch_device() == torch.device("mps"):
     from torch import mps
@@ -245,7 +245,7 @@ class CreateGradientMaskInvocation(BaseInvocation):

 def get_scheduler(
     context: InvocationContext,
-    scheduler_info: ModelField,
+    scheduler_info: ModelIdentifierField,
     scheduler_name: str,
     seed: int,
 ) -> Scheduler:

invokeai/app/invocations/model.py

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import SubModelType
+from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType

 from .baseinvocation import (
     BaseInvocation,
@@ -16,33 +16,39 @@ from .baseinvocation import (
 )


-class ModelField(BaseModel):
-    key: str = Field(description="Key of the model")
-    submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
+class ModelIdentifierField(BaseModel):
+    key: str = Field(description="The model's unique key")
+    hash: str = Field(description="The model's BLAKE3 hash")
+    name: str = Field(description="The model's name")
+    base: BaseModelType = Field(description="The model's base model type")
+    type: ModelType = Field(description="The model's type")
+    submodel_type: Optional[SubModelType] = Field(
+        description="The submodel to load, if this is a main model", default=None
+    )


 class LoRAField(BaseModel):
-    lora: ModelField = Field(description="Info to load lora model")
+    lora: ModelIdentifierField = Field(description="Info to load lora model")
     weight: float = Field(description="Weight to apply to lora model")


 class UNetField(BaseModel):
-    unet: ModelField = Field(description="Info to load unet submodel")
-    scheduler: ModelField = Field(description="Info to load scheduler submodel")
+    unet: ModelIdentifierField = Field(description="Info to load unet submodel")
+    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")


 class CLIPField(BaseModel):
-    tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
-    text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
     skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


 class VAEField(BaseModel):
-    vae: ModelField = Field(description="Info to load vae submodel")
+    vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
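With every submodel field now typed as `ModelIdentifierField`, a loader can derive per-submodel identifiers by copying the main identifier and setting `submodel_type`. A sketch of that pattern using pydantic v2's `model_copy` (the loaders' exact code is not part of this hunk; `SubModelType` members are from `invokeai.backend.model_manager.config`):

```python
from invokeai.backend.model_manager.config import SubModelType

# Given `model: ModelIdentifierField` for a main model, derive the
# identifiers that UNetField / CLIPField / VAEField expect.
unet = model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = model.model_copy(update={"submodel_type": SubModelType.VAE})
```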
@@ -84,7 +90,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""

-    model: ModelField = InputField(
+    model: ModelIdentifierField = InputField(
         description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
     )
     # TODO: precision?
@@ -119,7 +125,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
 class LoRALoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""

-    lora: ModelField = InputField(
+    lora: ModelIdentifierField = InputField(
         description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
     )
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -190,7 +196,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
 class SDXLLoRALoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""

-    lora: ModelField = InputField(
+    lora: ModelIdentifierField = InputField(
         description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
     )
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -264,7 +270,7 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
 class VAELoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""

-    vae_model: ModelField = InputField(
+    vae_model: ModelIdentifierField = InputField(
         description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
     )

invokeai/app/invocations/sdxl.py

@@ -8,7 +8,7 @@ from .baseinvocation import (
     invocation,
     invocation_output,
 )
-from .model import CLIPField, ModelField, UNetField, VAEField
+from .model import CLIPField, ModelIdentifierField, UNetField, VAEField


 @invocation_output("sdxl_model_loader_output")
@@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
 class SDXLModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl base model, outputting its submodels."""

-    model: ModelField = InputField(
+    model: ModelIdentifierField = InputField(
         description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
     )
     # TODO: precision?
@@ -72,7 +72,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
 class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl refiner model, outputting its submodels."""

-    model: ModelField = InputField(
+    model: ModelIdentifierField = InputField(
         description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
     )
     # TODO: precision?

invokeai/app/invocations/t2i_adapter.py

@@ -10,14 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
-from invokeai.app.invocations.model import ModelField
+from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext


 class T2IAdapterField(BaseModel):
     image: ImageField = Field(description="The T2I-Adapter image prompt.")
-    t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
+    t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
     weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@@ -52,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):

     # Inputs
     image: ImageField = InputField(description="The IP-Adapter image prompt.")
-    t2i_adapter_model: ModelField = InputField(
+    t2i_adapter_model: ModelIdentifierField = InputField(
         description="The T2I-Adapter model.",
         title="T2I-Adapter Model",
         input=Input.Direct,

invokeai/app/services/shared/invocation_context.py

@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
 if TYPE_CHECKING:
     from invokeai.app.invocations.baseinvocation import BaseInvocation
-    from invokeai.app.invocations.model import ModelField
+    from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

 """
@@ -300,7 +300,7 @@ class ConditioningInterface(InvocationContextInterface):

 class ModelsInterface(InvocationContextInterface):
-    def exists(self, identifier: Union[str, "ModelField"]) -> bool:
+    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
         """Checks if a model exists.

         Args:
@@ -314,7 +314,9 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.store.exists(identifier.key)

-    def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load(
+        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
+    ) -> LoadedModel:
         """Loads a model.

         Args:
@@ -361,7 +363,7 @@ class ModelsInterface(InvocationContextInterface):

         return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)

-    def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
+    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
         """Gets a model's config.

         Args:
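Per the signatures above, `ModelsInterface` continues to accept either a raw key string or the new identifier. A sketch of both call forms from inside a node's `invoke()` (the key string is hypothetical, and the context-manager usage of `LoadedModel` is assumed from InvokeAI's usual pattern rather than shown in this diff):

```python
# `context` is the InvocationContext passed to invoke().
# By identifier field (e.g. a UNetField's unet member):
loaded = context.models.load(self.unet.unet)

# Or by raw model key (hypothetical key):
loaded = context.models.load("26d0b813-example-key")

# LoadedModel is typically used as a context manager that yields the
# resident model object.
with loaded as unet:
    ...
```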

invokeai/invocation_api/__init__.py

@@ -36,7 +36,7 @@ from invokeai.app.invocations.model import (
     CLIPField,
     CLIPOutput,
     LoRALoaderOutput,
-    ModelField,
+    ModelIdentifierField,
     ModelLoaderOutput,
     SDXLLoRALoaderOutput,
     UNetField,
@@ -114,7 +114,7 @@ __all__ = [
     "MetadataItemOutput",
     "MetadataOutput",
     # invokeai.app.invocations.model
-    "ModelField",
+    "ModelIdentifierField",
     "UNetField",
     "CLIPField",
     "VAEField",