refactor(nodes): model identifiers

- All models are identified by a key and optionally a submodel type via new model `ModelField`. Previously, a few model types had their own class, but not all of them. This inconsistency just added complexity without any benefit.
- Update all invocation to use the new format.
- In the node API, models are loaded by key or an instance of `ModelField` as a convenience.
- Add an enriched model schema for metadata. It includes key, hash, name, base and type.
This commit is contained in:
psychedelicious
2024-03-06 19:37:15 +11:00
parent afd9ae7712
commit 528ac5dd25
15 changed files with 229 additions and 288 deletions

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -16,33 +16,34 @@ from .baseinvocation import (
)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
vae: ModelField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -74,18 +75,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@ -96,46 +85,24 @@ class LoRAModelField(BaseModel):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VaeField(vae=vae),
)
@ -151,7 +118,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -167,38 +134,33 @@ class LoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -225,7 +187,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -247,51 +209,45 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -299,17 +255,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: VAEModelField = InputField(
vae_model: ModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
@ -321,7 +271,7 @@ class VaeLoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
return VAEOutput(vae=VaeField(vae=self.vae_model))
@invocation_output("seamless_output")