refactor(nodes): model identifiers

- All models are now identified by a key and, optionally, a submodel type via the new `ModelField` model. Previously, some model types had their own field class and others did not; the inconsistency added complexity without any benefit.
- Update all invocations to use the new format.
- In the node API, models can be loaded by key or by a `ModelField` instance as a convenience (see the sketches after this list).
- Add an enriched model schema for metadata, including key, hash, name, base and type.
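
For illustration, a minimal sketch of the unified identifier and the loading convenience described above, assuming Pydantic v2. The `ModelField` shape mirrors the definition in the diff below; `resolve_key` is a hypothetical helper standing in for the key-or-`ModelField` handling that `context.models.load` performs, not the actual API:

from typing import Optional, Union

from pydantic import BaseModel, Field


class ModelField(BaseModel):
    """Identifies any installed model by its record key plus an optional submodel type."""

    key: str = Field(description="Key of the model")
    # The real field is Optional[SubModelType]; a plain string keeps this sketch self-contained.
    submodel_type: Optional[str] = Field(default=None, description="Submodel type")


def resolve_key(identifier: Union[str, ModelField]) -> str:
    """Hypothetical helper mirroring how the node API accepts a bare key or a ModelField."""
    return identifier if isinstance(identifier, str) else identifier.key


# Both calls resolve to the same model record key.
assert resolve_key("abc123") == resolve_key(ModelField(key="abc123"))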
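
And a sketch of the enriched metadata schema, again assuming Pydantic v2 and using plain strings where the real code uses the `BaseModelType` and `ModelType` enums; it matches the `ModelMetadataField` added in the metadata invocation diff below:

from pydantic import BaseModel


class ModelMetadataField(BaseModel):
    """Enriched model identifier written into image metadata."""

    key: str   # model record key
    hash: str  # content hash of the model files
    name: str  # human-readable model name
    base: str  # base model architecture (BaseModelType enum in the real code)
    type: str  # model type such as "lora" or "controlnet" (ModelType enum in the real code)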
Author: psychedelicious
Date: 2024-03-06 19:37:15 +11:00
Parent: afd9ae7712
Commit: 528ac5dd25
15 changed files with 229 additions and 288 deletions

View File

@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer) assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel) assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw) assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
@ -133,10 +133,10 @@ class SDXLPromptInvocationBase:
lora_prefix: str, lora_prefix: str,
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer) assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection)) assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) lora_info = context.models.load(lora.lora)
lora_model = lora_info.model lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw) assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight) yield (lora_model, lora.weight)

View File

@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
WithBoard, WithBoard,
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
] ]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use") control_model: ModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct) control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
) )

View File

@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion # endregion
class MetadataField(RootModel): class MetadataField(RootModel[dict[str, Any]]):
""" """
Pydantic model for metadata with custom root of type dict[str, Any]. Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema. Metadata is stored without a strict schema.

View File

@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.") ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.") image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs # Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField( ip_adapter_model: ModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1 description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
) )
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs( image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
) )
assert len(image_encoder_models) == 1 assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=self.ip_adapter_model, ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model, image_encoder_model=ModelField(key=image_encoder_models[0].key),
weight=self.weight, weight=self.weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,

View File

@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -75,7 +76,7 @@ from .baseinvocation import (
invocation_output, invocation_output,
) )
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField from .model import ModelField, UNetField, VaeField
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
) )
if image_tensor is not None: if image_tensor is not None:
vae_info = context.models.load(**self.vae.vae.model_dump()) vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelInfo, scheduler_info: ModelField,
scheduler_name: str, scheduler_name: str,
seed: int, seed: int,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0? # and if weight is None, populate with default 1.0?
controlnet_data = [] controlnet_data = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key)) control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
# control_models.append(control_model) # control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = [] conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter: for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key) context.models.load(single_ip_adapter.ip_adapter_model)
) )
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key) image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list): if not isinstance(single_ipa_image_fields, list):
@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model: with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel. # Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model single_ipa_images, image_encoder_model
@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = [] t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter: for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key) t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name) image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.models.load(**self.unet.unet.model_dump()) unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name) latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump()) vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device) latents = latents.to(vae.device)
@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(**self.vae.vae.model_dump()) vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3: if image_tensor.dim() == 3:

View File

@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
OutputField, OutputField,
UIType, UIType,
) )
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__ from ...version import __version__
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value) value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel): class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field""" """LoRA Metadata Field"""
model: LoRAModelField = Field(description=FieldDescriptions.lora_model) model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight) weight: float = Field(description=FieldDescriptions.lora_weight)
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model""" """IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field( ip_adapter_model: ModelMetadataField = Field(
description="The IP-Adapter model.", description="The IP-Adapter model.",
) )
weight: Union[float, list[float]] = Field( weight: Union[float, list[float]] = Field(
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output") @invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput): class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output""" """Metadata Item Output"""
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None, default=None,
description="The number of skipped CLIP layers", description="The number of skipped CLIP layers",
) )
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference") model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField( controlnets: Optional[list[ControlNetMetadataField]] = InputField(
default=None, description="The ControlNets used for inference" default=None, description="The ControlNets used for inference"
) )
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField( ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference" default=None, description="The IP Adapters used for inference"
) )
t2iAdapters: Optional[list[T2IAdapterField]] = InputField( t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference" default=None, description="The IP Adapters used for inference"
) )
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference") loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None, default=None,
description="The name of the initial image", description="The name of the initial image",
) )
vae: Optional[VAEModelField] = InputField( vae: Optional[ModelMetadataField] = InputField(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
) )
# SDXL Refiner # SDXL Refiner
refiner_model: Optional[MainModelField] = InputField( refiner_model: Optional[ModelMetadataField] = InputField(
default=None, default=None,
description="The SDXL Refiner model used", description="The SDXL Refiner model used",
) )
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput: def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object""" """Collects and outputs a CoreMetadata object"""
return MetadataOutput( as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
metadata=MetadataField.model_validate( as_dict["app_version"] = __version__
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
) return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
)
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -16,33 +16,34 @@ from .baseinvocation import (
) )
class ModelInfo(BaseModel): class ModelField(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class LoraInfo(ModelInfo): class LoRAField(BaseModel):
weight: float = Field(description="Lora's weight which to use when apply to model") lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoRAField] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder") skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoRAField] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -74,18 +75,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation( @invocation(
"main_model_loader", "main_model_loader",
title="Main Model", title="Main Model",
@ -96,46 +85,24 @@ class LoRAModelField(BaseModel):
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision? # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
key = self.model.key
# TODO: not found exceptions # TODO: not found exceptions
if not context.models.exists(key): if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {key}") raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
unet=ModelInfo( clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
key=key, vae=VaeField(vae=vae),
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
) )
@ -151,7 +118,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
class LoraLoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
default=None, default=None,
@ -167,38 +134,33 @@ class LoraLoaderInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key lora_key = self.lora.key
if not context.models.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!") raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet') raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip') raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoraLoaderOutput() output = LoraLoaderOutput()
if self.unet is not None: if self.unet is not None:
output.unet = copy.deepcopy(self.unet) output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoRAField(
key=lora_key, lora=self.lora,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip is not None: if self.clip is not None:
output.clip = copy.deepcopy(self.clip) output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoRAField(
key=lora_key, lora=self.lora,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
@ -225,7 +187,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
class SDXLLoraLoaderInvocation(BaseInvocation): class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
default=None, default=None,
@ -247,51 +209,45 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key lora_key = self.lora.key
if not context.models.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!") raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet') raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip') raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2') raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput() output = SDXLLoraLoaderOutput()
if self.unet is not None: if self.unet is not None:
output.unet = copy.deepcopy(self.unet) output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoRAField(
key=lora_key, lora=self.lora,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip is not None: if self.clip is not None:
output.clip = copy.deepcopy(self.clip) output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoRAField(
key=lora_key, lora=self.lora,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip2 is not None: if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2) output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append( output.clip2.loras.append(
LoraInfo( LoRAField(
key=lora_key, lora=self.lora,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
@ -299,17 +255,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
return output return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation): class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: VAEModelField = InputField( vae_model: ModelField = InputField(
description=FieldDescriptions.vae_model, description=FieldDescriptions.vae_model,
input=Input.Direct, input=Input.Direct,
title="VAE", title="VAE",
@ -321,7 +271,7 @@ class VaeLoaderInvocation(BaseInvocation):
if not context.models.exists(key): if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!") raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) return VAEOutput(vae=VaeField(vae=self.vae_model))
@invocation_output("seamless_output") @invocation_output("seamless_output")

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField from .model import ClipField, ModelField, UNetField, VaeField
@invocation_output("sdxl_model_loader_output") @invocation_output("sdxl_model_loader_output")
@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: MainModelField = InputField( model: ModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput( return SDXLModelLoaderOutput(
unet=UNetField( unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
unet=ModelInfo( clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
key=model_key, clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
submodel_type=SubModelType.UNet, vae=VaeField(vae=vae),
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
) )
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: MainModelField = InputField( model: ModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
) )
# TODO: precision? # TODO: precision?
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput( return SDXLRefinerModelLoaderOutput(
unet=UNetField( unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
unet=ModelInfo( clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
key=model_key, vae=VaeField(vae=vae),
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
) )

View File

@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
) )
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel): class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.") image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.") t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.") image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField( t2i_adapter_model: ModelField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct, input=Input.Direct,

View File

@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase.""" """Implementation of ModelManagerServiceBase."""
import torch import torch
from typing_extensions import Self from typing_extensions import Self

View File

@ -130,6 +130,17 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@abstractmethod
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod @abstractmethod
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
""" """
Return True if a model with the indicated key exists in the databse. Return True if a model with the indicated key exists in the databse.

View File

@ -1,7 +1,7 @@
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
from PIL.Image import Image from PIL.Image import Image
from torch import Tensor from torch import Tensor
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
""" """
@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool: def exists(self, identifier: Union[str, "ModelField"]) -> bool:
"""Checks if a model exists. """Checks if a model exists.
Args: Args:
key: The key of the model. identifier: The key or ModelField representing the model.
Returns: Returns:
True if the model exists, False if not. True if the model exists, False if not.
""" """
return self._services.model_manager.store.exists(key) if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model. """Loads a model.
Args: Args:
key: The key of the model. identifier: The key or ModelField representing the model.
submodel_type: The submodel of the model to get. submodel_type: The submodel of the model to get.
Returns: Returns:
@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build # The model manager emits events as it loads the model. It needs the context data to build
# the event payloads. # the event payloads.
return self._services.model_manager.load_model_by_key( if isinstance(identifier, str):
key=key, submodel_type=submodel_type, context_data=self._data model = self._services.model_manager.store.get_model(identifier)
) return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface):
Returns: Returns:
An object representing the loaded model. An object representing the loaded model.
""" """
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
def get_config(self, key: str) -> AnyModelConfig: configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
"""Gets a model's config. """Gets a model's config.
Args: Args:
key: The key of the model. identifier: The key or ModelField representing the model.
Returns: Returns:
The model's config. The model's config.
""" """
return self._services.model_manager.store.get_model(key=key) if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: return self._services.model_manager.store.get_model(identifier.key)
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]: def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path. """Searches for models by path.

View File

@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt): for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1] name_or_key = trigger[1:-1]
try: try:
loaded_model = context.models.load(key=name_or_key) loaded_model = context.models.load(name_or_key)
model = loaded_model.model model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw) assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base assert loaded_model.config.base == base

View File

@ -35,17 +35,13 @@ from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOut
from invokeai.app.invocations.model import ( from invokeai.app.invocations.model import (
ClipField, ClipField,
CLIPOutput, CLIPOutput,
LoraInfo,
LoraLoaderOutput, LoraLoaderOutput,
LoRAModelField, ModelField,
MainModelField,
ModelInfo,
ModelLoaderOutput, ModelLoaderOutput,
SDXLLoraLoaderOutput, SDXLLoraLoaderOutput,
UNetField, UNetField,
UNetOutput, UNetOutput,
VaeField, VaeField,
VAEModelField,
VAEOutput, VAEOutput,
) )
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (
@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
@ -118,14 +114,10 @@ __all__ = [
"MetadataItemOutput", "MetadataItemOutput",
"MetadataOutput", "MetadataOutput",
# invokeai.app.invocations.model # invokeai.app.invocations.model
"ModelInfo", "ModelField",
"LoraInfo",
"UNetField", "UNetField",
"ClipField", "ClipField",
"VaeField", "VaeField",
"MainModelField",
"LoRAModelField",
"VAEModelField",
"UNetOutput", "UNetOutput",
"VAEOutput", "VAEOutput",
"CLIPOutput", "CLIPOutput",
@ -166,7 +158,7 @@ __all__ = [
# invokeai.app.services.config.config_default # invokeai.app.services.config.config_default
"InvokeAIAppConfig", "InvokeAIAppConfig",
# invokeai.backend.model_management.model_manager # invokeai.backend.model_management.model_manager
"LoadedModelInfo", "LoadedModel",
# invokeai.backend.model_management.models.base # invokeai.backend.model_management.models.base
"BaseModelType", "BaseModelType",
"ModelType", "ModelType",