feat(nodes): "ModelField" -> "ModelIdentifierField", add hash/name/base/type

This commit is contained in:
psychedelicious
2024-03-09 19:43:24 +11:00
parent 67d26cd633
commit 92b0d13d0e
8 changed files with 44 additions and 36 deletions

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
@ -16,33 +16,39 @@ from .baseinvocation import (
)
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class ModelIdentifierField(BaseModel):
key: str = Field(description="The model's unique key")
hash: str = Field(description="The model's BLAKE3 hash")
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
)
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
lora: ModelIdentifierField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel):
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
unet: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel")
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -84,7 +90,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelField = InputField(
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision?
@ -119,7 +125,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -190,7 +196,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -264,7 +270,7 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelField = InputField(
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
)